From e81b3c0cd1f18c245f0bef9c07c6f1e897dbefbb Mon Sep 17 00:00:00 2001
From: Gabefire <33893811+Gabefire@users.noreply.github.com>
Date: Wed, 11 Sep 2024 15:00:54 -0500
Subject: [PATCH 1/8] testing workflow

---
 .github/workflows/python-package-shared.yml | 7 +++++--
 libs/labelbox/pyproject.toml                | 7 ++++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/python-package-shared.yml b/.github/workflows/python-package-shared.yml
index 4311020d8..acd30b299 100644
--- a/.github/workflows/python-package-shared.yml
+++ b/.github/workflows/python-package-shared.yml
@@ -18,7 +18,7 @@ on:
       test-env:
         required: true
         type: string
-      fixture-profile: 
+      fixture-profile:
         required: true
         type: boolean
 
@@ -36,6 +36,9 @@ jobs:
       - name: Linting
         working-directory: libs/labelbox
         run: rye run lint
+      - name: Format
+        working-directory: libs/labelbox
+        run: rye fmt --check
   integration:
     runs-on: ubuntu-latest
     concurrency:
@@ -78,4 +81,4 @@ jobs:
         run: |
           rye sync -f --features labelbox/data
           rye run unit -n 32
-          rye run data -n 32
\ No newline at end of file
+          rye run data -n 32
diff --git a/libs/labelbox/pyproject.toml b/libs/labelbox/pyproject.toml
index ac167cdcb..771117a01 100644
--- a/libs/labelbox/pyproject.toml
+++ b/libs/labelbox/pyproject.toml
@@ -64,7 +64,6 @@ build-backend = "hatchling.build"
 [tool.rye]
 managed = true
 dev-dependencies = [
-    "yapf>=0.40.2",
     "mypy>=1.9.0",
     "types-pillow>=10.2.0.20240311",
     "types-python-dateutil>=2.9.0.20240316",
@@ -72,6 +71,9 @@ dev-dependencies = [
     "types-tqdm>=4.66.0.20240106",
 ]
 
+[tool.ruff]
+line-length = 80
+
 [tool.rye.scripts]
 unit = "pytest tests/unit"
 # https://github.com/Labelbox/labelbox-python/blob/7c84fdffbc14fd1f69d2a6abdcc0087dc557fa4e/Makefile
@@ -87,9 +89,8 @@ unit = "pytest tests/unit"
 # LABELBOX_TEST_BASE_URL="http://host.docker.internal:8080" \
 integration = { cmd = "pytest tests/integration" }
 data = { cmd = "pytest tests/data" }
-yapf-lint = "yapf tests src -i --verbose --recursive --parallel --style \"google\""
 mypy-lint = "mypy src --pretty --show-error-codes --non-interactive --install-types"
-lint = { chain = ["yapf-lint", "mypy-lint"] }
+lint = { chain = ["mypy-lint"] }
 test = { chain = ["lint", "unit", "integration"] }
 
 [tool.hatch.metadata]

From 2f966b0c28833dc1396e75355bb55f23151e49c4 Mon Sep 17 00:00:00 2001
From: Gabefire <33893811+Gabefire@users.noreply.github.com>
Date: Wed, 11 Sep 2024 15:05:11 -0500
Subject: [PATCH 2/8] reformatted

---
 libs/labelbox/src/labelbox/__init__.py        |   43 +-
 libs/labelbox/src/labelbox/adv_client.py      |   62 +-
 libs/labelbox/src/labelbox/client.py          | 1192 ++++++++-----
 .../data/annotation_types/__init__.py         |    9 +-
 .../data/annotation_types/annotation.py       |    4 +-
 .../data/annotation_types/base_annotation.py  |    6 +-
 .../classification/__init__.py                |    3 +-
 .../classification/classification.py          |   19 +-
 .../data/annotation_types/collection.py       |   43 +-
 .../data/annotation_types/data/__init__.py    |    2 +-
 .../data/annotation_types/data/audio.py       |    2 +-
 .../data/annotation_types/data/base_data.py   |    1 +
 .../data/annotation_types/data/dicom.py       |    2 +-
 .../data/annotation_types/data/document.py    |    2 +-
 .../data/generic_data_row_data.py             |    6 +-
 .../data/annotation_types/data/html.py        |    2 +-
 .../data/llm_prompt_creation.py               |    2 +-
 .../data/llm_prompt_response_creation.py      |    5 +-
 .../data/llm_response_creation.py             |    2 +-
 .../data/annotation_types/data/raster.py      |   55 +-
 .../data/annotation_types/data/text.py        |   25 +-
 .../data/annotation_types/data/tiled_image.py |  330 ++--
 .../data/annotation_types/data/video.py       |   41 +-
.../labelbox/data/annotation_types/feature.py | 1 + .../annotation_types/geometry/geometry.py | 41 +- .../data/annotation_types/geometry/line.py | 43 +- .../data/annotation_types/geometry/mask.py | 42 +- .../data/annotation_types/geometry/point.py | 31 +- .../data/annotation_types/geometry/polygon.py | 31 +- .../annotation_types/geometry/rectangle.py | 54 +- .../labelbox/data/annotation_types/label.py | 121 +- .../llm_prompt_response/__init__.py | 2 +- .../llm_prompt_response/prompt.py | 12 +- .../data/annotation_types/metrics/__init__.py | 6 +- .../data/annotation_types/metrics/base.py | 20 +- .../metrics/confusion_matrix.py | 17 +- .../data/annotation_types/metrics/scalar.py | 29 +- .../src/labelbox/data/annotation_types/mmc.py | 11 +- .../ner/conversation_entity.py | 5 +- .../annotation_types/ner/document_entity.py | 3 +- .../data/annotation_types/ner/text_entity.py | 11 +- .../data/annotation_types/relationship.py | 6 +- .../labelbox/data/annotation_types/types.py | 19 +- .../labelbox/data/annotation_types/video.py | 105 +- libs/labelbox/src/labelbox/data/generator.py | 7 +- .../src/labelbox/data/metrics/__init__.py | 5 +- .../metrics/confusion_matrix/calculation.py | 171 +- .../confusion_matrix/confusion_matrix.py | 76 +- .../src/labelbox/data/metrics/group.py | 71 +- .../labelbox/data/metrics/iou/calculation.py | 199 ++- .../src/labelbox/data/metrics/iou/iou.py | 53 +- libs/labelbox/src/labelbox/data/mixins.py | 13 +- libs/labelbox/src/labelbox/data/ontology.py | 70 +- .../data/serialization/coco/annotation.py | 19 +- .../data/serialization/coco/categories.py | 3 +- .../data/serialization/coco/converter.py | 78 +- .../labelbox/data/serialization/coco/image.py | 2 +- .../serialization/coco/instance_dataset.py | 188 +- .../serialization/coco/panoptic_dataset.py | 149 +- .../labelbox/data/serialization/coco/path.py | 2 +- .../data/serialization/ndjson/base.py | 12 +- .../serialization/ndjson/classification.py | 436 +++-- .../data/serialization/ndjson/converter.py | 50 +- .../data/serialization/ndjson/label.py | 275 ++- .../data/serialization/ndjson/metric.py | 110 +- .../labelbox/data/serialization/ndjson/mmc.py | 36 +- .../data/serialization/ndjson/objects.py | 871 +++++---- .../data/serialization/ndjson/relationship.py | 49 +- libs/labelbox/src/labelbox/exceptions.py | 51 +- libs/labelbox/src/labelbox/orm/comparison.py | 46 +- libs/labelbox/src/labelbox/orm/db_object.py | 122 +- libs/labelbox/src/labelbox/orm/model.py | 95 +- libs/labelbox/src/labelbox/orm/query.py | 214 ++- libs/labelbox/src/labelbox/pagination.py | 80 +- libs/labelbox/src/labelbox/parser.py | 7 +- libs/labelbox/src/labelbox/schema/__init__.py | 2 +- .../src/labelbox/schema/annotation_import.py | 395 +++-- .../src/labelbox/schema/asset_attachment.py | 31 +- libs/labelbox/src/labelbox/schema/batch.py | 117 +- .../labelbox/src/labelbox/schema/benchmark.py | 14 +- .../labelbox/schema/bulk_import_request.py | 503 +++--- libs/labelbox/src/labelbox/schema/catalog.py | 135 +- .../schema/confidence_presence_checker.py | 5 +- .../labelbox/schema/create_batches_task.py | 6 +- libs/labelbox/src/labelbox/schema/data_row.py | 247 +-- .../src/labelbox/schema/data_row_metadata.py | 449 +++-- libs/labelbox/src/labelbox/schema/dataset.py | 403 +++-- .../labelbox/src/labelbox/schema/embedding.py | 11 +- libs/labelbox/src/labelbox/schema/enums.py | 29 +- .../src/labelbox/schema/export_filters.py | 142 +- .../src/labelbox/schema/export_params.py | 11 +- .../src/labelbox/schema/export_task.py | 246 ++- 
.../src/labelbox/schema/foundry/app.py | 4 +- .../labelbox/schema/foundry/foundry_client.py | 28 +- .../src/labelbox/schema/foundry/model.py | 2 +- .../src/labelbox/schema/iam_integration.py | 6 +- libs/labelbox/src/labelbox/schema/id_type.py | 3 +- .../src/labelbox/schema/identifiables.py | 6 +- .../schema/internal/data_row_uploader.py | 31 +- .../schema/internal/data_row_upsert_item.py | 49 +- .../internal/descriptor_file_creator.py | 108 +- libs/labelbox/src/labelbox/schema/invite.py | 9 +- libs/labelbox/src/labelbox/schema/label.py | 14 +- .../src/labelbox/schema/labeling_frontend.py | 6 +- .../src/labelbox/schema/labeling_service.py | 26 +- .../schema/labeling_service_dashboard.py | 62 +- .../schema/labeling_service_status.py | 18 +- .../src/labelbox/schema/media_type.py | 31 +- libs/labelbox/src/labelbox/schema/model.py | 33 +- .../src/labelbox/schema/model_config.py | 4 +- .../labelbox/src/labelbox/schema/model_run.py | 467 ++--- libs/labelbox/src/labelbox/schema/ontology.py | 292 +-- .../src/labelbox/schema/ontology_kind.py | 76 +- .../src/labelbox/schema/organization.py | 120 +- libs/labelbox/src/labelbox/schema/project.py | 889 ++++++---- .../labelbox/schema/project_model_config.py | 16 +- .../src/labelbox/schema/project_overview.py | 17 +- .../labelbox/schema/project_resource_tag.py | 2 +- .../src/labelbox/schema/resource_tag.py | 2 +- libs/labelbox/src/labelbox/schema/review.py | 6 +- libs/labelbox/src/labelbox/schema/role.py | 14 +- .../src/labelbox/schema/search_filters.py | 147 +- .../schema/send_to_annotate_params.py | 61 +- .../src/labelbox/schema/serialization.py | 5 +- libs/labelbox/src/labelbox/schema/slice.py | 182 +- libs/labelbox/src/labelbox/schema/task.py | 167 +- libs/labelbox/src/labelbox/schema/user.py | 40 +- .../src/labelbox/schema/user_group.py | 96 +- libs/labelbox/src/labelbox/schema/webhook.py | 37 +- libs/labelbox/src/labelbox/types.py | 2 +- libs/labelbox/src/labelbox/typing_imports.py | 5 +- libs/labelbox/src/labelbox/utils.py | 31 +- libs/labelbox/tests/conftest.py | 604 ++++--- .../tests/data/annotation_import/conftest.py | 1580 ++++++++--------- .../test_annotation_import_limit.py | 57 +- .../test_bulk_import_request.py | 143 +- .../data/annotation_import/test_data_types.py | 24 +- .../test_generic_data_types.py | 233 ++- .../annotation_import/test_label_import.py | 108 +- .../test_mal_prediction_import.py | 49 +- .../test_mea_prediction_import.py | 227 ++- .../data/annotation_import/test_model_run.py | 87 +- .../test_ndjson_validation.py | 157 +- .../test_send_to_annotate_mea.py | 44 +- .../test_upsert_prediction_import.py | 101 +- .../classification/test_classification.py | 190 +- .../data/annotation_types/data/test_raster.py | 12 +- .../data/annotation_types/data/test_text.py | 20 +- .../data/annotation_types/data/test_video.py | 14 +- .../annotation_types/geometry/test_line.py | 2 +- .../annotation_types/geometry/test_mask.py | 143 +- .../annotation_types/geometry/test_point.py | 2 +- .../annotation_types/geometry/test_polygon.py | 8 +- .../geometry/test_rectangle.py | 6 +- .../data/annotation_types/test_annotation.py | 57 +- .../data/annotation_types/test_collection.py | 69 +- .../tests/data/annotation_types/test_label.py | 274 +-- .../data/annotation_types/test_metrics.py | 242 +-- .../tests/data/annotation_types/test_ner.py | 18 +- .../data/annotation_types/test_tiled_image.py | 68 +- .../tests/data/annotation_types/test_video.py | 15 +- libs/labelbox/tests/data/conftest.py | 47 +- libs/labelbox/tests/data/export/conftest.py | 568 +++--- 
.../data/export/legacy/test_export_catalog.py | 10 +- .../data/export/legacy/test_export_dataset.py | 26 +- .../export/legacy/test_export_model_run.py | 23 +- .../data/export/legacy/test_export_project.py | 181 +- .../data/export/legacy/test_export_slice.py | 10 +- .../data/export/legacy/test_export_video.py | 275 +-- .../data/export/legacy/test_legacy_export.py | 179 +- .../test_export_data_rows_streamable.py | 86 +- .../test_export_dataset_streamable.py | 68 +- .../test_export_embeddings_streamable.py | 74 +- .../test_export_model_run_streamable.py | 28 +- .../test_export_project_streamable.py | 208 ++- .../test_export_video_streamable.py | 108 +- .../data/metrics/confusion_matrix/conftest.py | 598 ++++--- .../test_confusion_matrix_data_row.py | 63 +- .../test_confusion_matrix_feature.py | 54 +- .../data/metrics/iou/data_row/conftest.py | 1262 ++++++------- .../data/metrics/iou/feature/conftest.py | 301 ++-- .../metrics/iou/feature/test_feature_iou.py | 3 +- .../data/serialization/coco/test_coco.py | 26 +- .../serialization/ndjson/test_checklist.py | 408 +++-- .../ndjson/test_classification.py | 10 +- .../serialization/ndjson/test_conversation.py | 194 +- .../serialization/ndjson/test_data_gen.py | 54 +- .../data/serialization/ndjson/test_dicom.py | 197 +- .../serialization/ndjson/test_document.py | 56 +- .../ndjson/test_export_video_objects.py | 1140 ++++++------ .../serialization/ndjson/test_free_text.py | 113 +- .../serialization/ndjson/test_global_key.py | 33 +- .../data/serialization/ndjson/test_image.py | 91 +- .../data/serialization/ndjson/test_metric.py | 22 +- .../data/serialization/ndjson/test_mmc.py | 13 +- .../ndjson/test_ndlabel_subclass_matching.py | 12 +- .../data/serialization/ndjson/test_nested.py | 7 +- .../serialization/ndjson/test_polyline.py | 13 +- .../data/serialization/ndjson/test_radio.py | 104 +- .../serialization/ndjson/test_rectangle.py | 48 +- .../serialization/ndjson/test_relationship.py | 12 +- .../data/serialization/ndjson/test_text.py | 41 +- .../serialization/ndjson/test_text_entity.py | 13 +- .../data/serialization/ndjson/test_video.py | 760 ++++---- .../tests/data/test_data_row_metadata.py | 287 +-- .../tests/data/test_prefetch_generator.py | 3 +- libs/labelbox/tests/integration/conftest.py | 576 +++--- .../integration/schema/test_user_group.py | 16 +- libs/labelbox/tests/integration/test_batch.py | 189 +- .../tests/integration/test_batches.py | 12 +- .../test_chat_evaluation_ontology_project.py | 110 +- .../tests/integration/test_client_errors.py | 13 +- .../test_data_row_delete_metadata.py | 218 ++- .../tests/integration/test_data_rows.py | 855 +++++---- .../integration/test_data_rows_upsert.py | 309 ++-- .../tests/integration/test_dataset.py | 78 +- .../integration/test_delegated_access.py | 108 +- .../tests/integration/test_embedding.py | 21 +- .../tests/integration/test_ephemeral.py | 12 +- .../tests/integration/test_feature_schema.py | 56 +- .../tests/integration/test_filtering.py | 11 +- .../tests/integration/test_foundry.py | 103 +- .../tests/integration/test_global_keys.py | 236 ++- libs/labelbox/tests/integration/test_label.py | 16 +- .../integration/test_labeling_dashboard.py | 82 +- .../integration/test_labeling_frontend.py | 6 +- .../test_labeling_parameter_overrides.py | 47 +- .../integration/test_labeling_service.py | 43 +- .../tests/integration/test_legacy_project.py | 15 +- .../tests/integration/test_model_config.py | 12 +- .../test_offline_chat_evaluation_project.py | 13 +- .../tests/integration/test_ontology.py | 251 +-- 
.../tests/integration/test_project.py | 119 +- .../integration/test_project_model_config.py | 63 +- .../test_project_set_model_setup_complete.py | 34 +- .../tests/integration/test_project_setup.py | 32 +- ...test_prompt_response_generation_project.py | 162 +- .../test_response_creation_project.py | 18 +- .../integration/test_send_to_annotate.py | 39 +- libs/labelbox/tests/integration/test_task.py | 68 +- .../tests/integration/test_task_queue.py | 30 +- .../tests/integration/test_user_and_org.py | 2 +- .../tests/integration/test_user_management.py | 107 +- .../tests/integration/test_webhook.py | 27 +- libs/labelbox/tests/unit/conftest.py | 145 +- .../unit/export_task/test_export_task.py | 105 +- .../export_task/test_unit_file_converter.py | 11 +- .../test_unit_file_retriever_by_line.py | 55 +- .../test_unit_file_retriever_by_offset.py | 37 +- .../export_task/test_unit_json_converter.py | 18 +- .../tests/unit/schema/test_user_group.py | 197 +- .../tests/unit/test_annotation_import.py | 48 +- .../tests/unit/test_data_row_upsert_data.py | 119 +- libs/labelbox/tests/unit/test_exceptions.py | 21 +- .../tests/unit/test_label_data_type.py | 37 +- libs/labelbox/tests/unit/test_mal_import.py | 85 +- .../tests/unit/test_ndjson_parsing.py | 2 +- libs/labelbox/tests/unit/test_project.py | 24 +- ...est_unit_delete_batch_data_row_metadata.py | 46 +- .../unit/test_unit_descriptor_file_creator.py | 35 +- .../tests/unit/test_unit_entity_meta.py | 14 +- .../tests/unit/test_unit_export_filters.py | 43 +- .../tests/unit/test_unit_label_import.py | 45 +- .../labelbox/tests/unit/test_unit_ontology.py | 306 ++-- .../tests/unit/test_unit_ontology_kind.py | 15 +- ...t_validate_labeling_parameter_overrides.py | 14 +- libs/labelbox/tests/unit/test_unit_query.py | 15 +- .../tests/unit/test_unit_search_filters.py | 155 +- libs/labelbox/tests/unit/test_unit_webhook.py | 9 +- libs/labelbox/tests/unit/test_utils.py | 35 +- libs/labelbox/tests/utils.py | 6 +- 271 files changed, 16955 insertions(+), 13067 deletions(-) diff --git a/libs/labelbox/src/labelbox/__init__.py b/libs/labelbox/src/labelbox/__init__.py index ac7efdc96..633e8f4c2 100644 --- a/libs/labelbox/src/labelbox/__init__.py +++ b/libs/labelbox/src/labelbox/__init__.py @@ -7,7 +7,12 @@ from labelbox.schema.model import Model from labelbox.schema.model_config import ModelConfig from labelbox.schema.bulk_import_request import BulkImportRequest -from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport, LabelImport, MEAToMALPredictionImport +from labelbox.schema.annotation_import import ( + MALPredictionImport, + MEAPredictionImport, + LabelImport, + MEAToMALPredictionImport, +) from labelbox.schema.dataset import Dataset from labelbox.schema.data_row import DataRow from labelbox.schema.catalog import Catalog @@ -18,16 +23,39 @@ from labelbox.schema.user import User from labelbox.schema.organization import Organization from labelbox.schema.task import Task -from labelbox.schema.export_task import StreamType, ExportTask, JsonConverter, JsonConverterOutput, FileConverter, FileConverterOutput, BufferedJsonConverterOutput -from labelbox.schema.labeling_frontend import LabelingFrontend, LabelingFrontendOptions +from labelbox.schema.export_task import ( + StreamType, + ExportTask, + JsonConverter, + JsonConverterOutput, + FileConverter, + FileConverterOutput, + BufferedJsonConverterOutput, +) +from labelbox.schema.labeling_frontend import ( + LabelingFrontend, + LabelingFrontendOptions, +) from labelbox.schema.asset_attachment import 
AssetAttachment from labelbox.schema.webhook import Webhook -from labelbox.schema.ontology import Ontology, OntologyBuilder, Classification, Option, Tool, FeatureSchema +from labelbox.schema.ontology import ( + Ontology, + OntologyBuilder, + Classification, + Option, + Tool, + FeatureSchema, +) from labelbox.schema.ontology import PromptResponseClassification from labelbox.schema.ontology import ResponseOption from labelbox.schema.role import Role, ProjectRole from labelbox.schema.invite import Invite, InviteLimit -from labelbox.schema.data_row_metadata import DataRowMetadataOntology, DataRowMetadataField, DataRowMetadata, DeleteDataRowMetadata +from labelbox.schema.data_row_metadata import ( + DataRowMetadataOntology, + DataRowMetadataField, + DataRowMetadata, + DeleteDataRowMetadata, +) from labelbox.schema.model_run import ModelRun, DataSplit from labelbox.schema.benchmark import Benchmark from labelbox.schema.iam_integration import IAMIntegration @@ -42,7 +70,10 @@ from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.schema.ontology_kind import OntologyKind -from labelbox.schema.project_overview import ProjectOverview, ProjectOverviewDetailed +from labelbox.schema.project_overview import ( + ProjectOverview, + ProjectOverviewDetailed, +) from labelbox.schema.labeling_service import LabelingService from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard from labelbox.schema.labeling_service_status import LabelingServiceStatus diff --git a/libs/labelbox/src/labelbox/adv_client.py b/libs/labelbox/src/labelbox/adv_client.py index 6eab78d68..626ac0279 100644 --- a/libs/labelbox/src/labelbox/adv_client.py +++ b/libs/labelbox/src/labelbox/adv_client.py @@ -12,7 +12,6 @@ class AdvClient: - def __init__(self, endpoint: str, api_key: str): self.endpoint = endpoint self.api_key = api_key @@ -32,8 +31,9 @@ def get_embeddings(self) -> List[Dict[str, Any]]: return self._request("GET", "/adv/v1/embeddings").get("results", []) def import_vectors_from_file(self, id: str, file_path: str, callback=None): - self._send_ndjson(f"/adv/v1/embeddings/{id}/_import_ndjson", file_path, - callback) + self._send_ndjson( + f"/adv/v1/embeddings/{id}/_import_ndjson", file_path, callback + ) def get_imported_vector_count(self, id: str) -> int: data = self._request("GET", f"/adv/v1/embeddings/{id}/vectors/_count") @@ -41,38 +41,42 @@ def get_imported_vector_count(self, id: str) -> int: def _create_session(self) -> Session: session = requests.session() - session.headers.update({ - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - }) + session.headers.update( + { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + ) return session - def _request(self, - method: str, - path: str, - data: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + def _request( + self, + method: str, + path: str, + data: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: url = f"{self.endpoint}{path}" requests_data = None if data: requests_data = json.dumps(data) - response = self.session.request(method, - url, - data=requests_data, - headers=headers) + response = self.session.request( + method, url, data=requests_data, headers=headers + ) if response.status_code != requests.codes.ok: - message = response.json().get('message') + message = 
response.json().get("message") if message: raise LabelboxError(message) else: response.raise_for_status() return response.json() - def _send_ndjson(self, - path: str, - file_path: str, - callback: Optional[Callable[[Dict[str, Any]], - None]] = None): + def _send_ndjson( + self, + path: str, + file_path: str, + callback: Optional[Callable[[Dict[str, Any]], None]] = None, + ): """ Sends an NDJson file in chunks. @@ -87,7 +91,7 @@ def upload_chunk(_buffer, _count): _headers = { "Content-Type": "application/x-ndjson", "X-Content-Lines": str(_count), - "Content-Length": str(buffer.tell()) + "Content-Length": str(buffer.tell()), } rsp = self._send_bytes(f"{self.endpoint}{path}", _buffer, _headers) rsp.raise_for_status() @@ -96,7 +100,7 @@ def upload_chunk(_buffer, _count): buffer = io.BytesIO() count = 0 - with open(file_path, 'rb') as fp: + with open(file_path, "rb") as fp: for line in fp: buffer.write(line) count += 1 @@ -107,10 +111,12 @@ def upload_chunk(_buffer, _count): if count: upload_chunk(buffer, count) - def _send_bytes(self, - url: str, - buffer: io.BytesIO, - headers: Optional[Dict[str, Any]] = None) -> Response: + def _send_bytes( + self, + url: str, + buffer: io.BytesIO, + headers: Optional[Dict[str, Any]] = None, + ) -> Response: buffer.seek(0) return self.session.put(url, headers=headers, data=buffer) diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 431ddbdc4..cda55c282 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -26,7 +26,9 @@ from labelbox.orm.model import Entity, Field from labelbox.pagination import PaginatedCollection from labelbox.schema import role -from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy +from labelbox.schema.conflict_resolution_strategy import ( + ConflictResolutionStrategy, +) from labelbox.schema.data_row import DataRow from labelbox.schema.catalog import Catalog from labelbox.schema.data_row_metadata import DataRowMetadataOntology @@ -38,26 +40,46 @@ from labelbox.schema.identifiables import DataRowIds from labelbox.schema.identifiables import GlobalKeys from labelbox.schema.labeling_frontend import LabelingFrontend -from labelbox.schema.media_type import MediaType, get_media_type_validation_error +from labelbox.schema.media_type import ( + MediaType, + get_media_type_validation_error, +) from labelbox.schema.model import Model from labelbox.schema.model_config import ModelConfig from labelbox.schema.model_run import ModelRun from labelbox.schema.ontology import Ontology, DeleteFeatureFromOntologyResult -from labelbox.schema.ontology import Tool, Classification, FeatureSchema, PromptResponseClassification +from labelbox.schema.ontology import ( + Tool, + Classification, + FeatureSchema, + PromptResponseClassification, +) from labelbox.schema.organization import Organization from labelbox.schema.project import Project -from labelbox.schema.quality_mode import QualityMode, BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS, \ - BENCHMARK_AUTO_AUDIT_PERCENTAGE, CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS, CONSENSUS_AUTO_AUDIT_PERCENTAGE +from labelbox.schema.quality_mode import ( + QualityMode, + BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS, + BENCHMARK_AUTO_AUDIT_PERCENTAGE, + CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS, + CONSENSUS_AUTO_AUDIT_PERCENTAGE, +) from labelbox.schema.queue_mode import QueueMode from labelbox.schema.role import Role -from labelbox.schema.send_to_annotate_params import SendToAnnotateFromCatalogParams, 
build_destination_task_queue_input, \ - build_predictions_input, build_annotations_input +from labelbox.schema.send_to_annotate_params import ( + SendToAnnotateFromCatalogParams, + build_destination_task_queue_input, + build_predictions_input, + build_annotations_input, +) from labelbox.schema.slice import CatalogSlice, ModelSlice from labelbox.schema.task import Task, DataUpsertTask from labelbox.schema.user import User from labelbox.schema.label_score import LabelScore -from labelbox.schema.ontology_kind import (OntologyKind, EditorTaskTypeMapper, - EditorTaskType) +from labelbox.schema.ontology_kind import ( + OntologyKind, + EditorTaskTypeMapper, + EditorTaskType, +) from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard logger = logging.getLogger(__name__) @@ -72,20 +94,22 @@ def python_version_info(): class Client: - """ A Labelbox client. + """A Labelbox client. Contains info necessary for connecting to a Labelbox server (URL, authentication key). Provides functions for querying and creating top-level data objects (Projects, Datasets). """ - def __init__(self, - api_key=None, - endpoint='https://api.labelbox.com/graphql', - enable_experimental=False, - app_url="https://app.labelbox.com", - rest_endpoint="https://api.labelbox.com/api/v1"): - """ Creates and initializes a Labelbox Client. + def __init__( + self, + api_key=None, + endpoint="https://api.labelbox.com/graphql", + enable_experimental=False, + app_url="https://app.labelbox.com", + rest_endpoint="https://api.labelbox.com/api/v1", + ): + """Creates and initializes a Labelbox Client. Logging is defaulted to level WARNING. To receive more verbose output to console, update `logging.level` to the appropriate level. @@ -106,7 +130,8 @@ def __init__(self, if api_key is None: if _LABELBOX_API_KEY not in os.environ: raise labelbox.exceptions.AuthenticationError( - "Labelbox API key not provided") + "Labelbox API key not provided" + ) api_key = os.environ[_LABELBOX_API_KEY] self.api_key = api_key @@ -123,7 +148,8 @@ def __init__(self, self._connection: requests.Session = self._init_connection() def _init_connection(self) -> requests.Session: - connection = requests.Session( + connection = ( + requests.Session() ) # using default connection pool size of 10 connection.headers.update(self._default_headers()) @@ -135,26 +161,31 @@ def headers(self) -> MappingProxyType: def _default_headers(self): return { - 'Authorization': 'Bearer %s' % self.api_key, - 'Accept': 'application/json', - 'Content-Type': 'application/json', - 'X-User-Agent': f"python-sdk {SDK_VERSION}", - 'X-Python-Version': f"{python_version_info()}", + "Authorization": "Bearer %s" % self.api_key, + "Accept": "application/json", + "Content-Type": "application/json", + "X-User-Agent": f"python-sdk {SDK_VERSION}", + "X-Python-Version": f"{python_version_info()}", } - @retry.Retry(predicate=retry.if_exception_type( - labelbox.exceptions.InternalServerError, - labelbox.exceptions.TimeoutError)) - def execute(self, - query=None, - params=None, - data=None, - files=None, - timeout=60.0, - experimental=False, - error_log_key="message", - raise_return_resource_not_found=False): - """ Sends a request to the server for the execution of the + @retry.Retry( + predicate=retry.if_exception_type( + labelbox.exceptions.InternalServerError, + labelbox.exceptions.TimeoutError, + ) + ) + def execute( + self, + query=None, + params=None, + data=None, + files=None, + timeout=60.0, + experimental=False, + error_log_key="message", + raise_return_resource_not_found=False, + 
): + """Sends a request to the server for the execution of the given query. Checks the response for errors and wraps errors @@ -199,26 +230,30 @@ def convert_value(value): params = { key: convert_value(value) for key, value in params.items() } - data = json.dumps({ - 'query': query, - 'variables': params - }).encode('utf-8') + data = json.dumps({"query": query, "variables": params}).encode( + "utf-8" + ) elif data is None: raise ValueError("query and data cannot both be none") - endpoint = self.endpoint if not experimental else self.endpoint.replace( - "/graphql", "/_gql") + endpoint = ( + self.endpoint + if not experimental + else self.endpoint.replace("/graphql", "/_gql") + ) try: headers = self._connection.headers.copy() if files: - del headers['Content-Type'] - del headers['Accept'] - request = requests.Request('POST', - endpoint, - headers=headers, - data=data, - files=files if files else None) + del headers["Content-Type"] + del headers["Accept"] + request = requests.Request( + "POST", + endpoint, + headers=headers, + data=data, + files=files if files else None, + ) prepped: requests.PreparedRequest = request.prepare() @@ -231,20 +266,30 @@ def convert_value(value): raise labelbox.exceptions.NetworkError(e) except Exception as e: raise labelbox.exceptions.LabelboxError( - "Unknown error during Client.query(): " + str(e), e) + "Unknown error during Client.query(): " + str(e), e + ) - if 200 <= response.status_code < 300 or response.status_code < 500 or response.status_code >= 600: + if ( + 200 <= response.status_code < 300 + or response.status_code < 500 + or response.status_code >= 600 + ): try: r_json = response.json() except Exception: raise labelbox.exceptions.LabelboxError( - "Failed to parse response as JSON: %s" % response.text) + "Failed to parse response as JSON: %s" % response.text + ) else: - if "upstream connect error or disconnect/reset before headers" in response.text: + if ( + "upstream connect error or disconnect/reset before headers" + in response.text + ): raise labelbox.exceptions.InternalServerError( - "Connection reset") + "Connection reset" + ) elif response.status_code == 502: - error_502 = '502 Bad Gateway' + error_502 = "502 Bad Gateway" raise labelbox.exceptions.InternalServerError(error_502) elif 500 <= response.status_code < 600: error_500 = f"Internal server http error {response.status_code}" @@ -253,7 +298,7 @@ def convert_value(value): errors = r_json.get("errors", []) def check_errors(keywords, *path): - """ Helper that looks for any of the given `keywords` in any of + """Helper that looks for any of the given `keywords` in any of current errors on paths (like error[path][component][to][keyword]). 
""" for error in errors: @@ -270,18 +315,23 @@ def get_error_status_code(error: dict) -> int: except: return 500 - if check_errors(["AUTHENTICATION_ERROR"], "extensions", - "code") is not None: + if ( + check_errors(["AUTHENTICATION_ERROR"], "extensions", "code") + is not None + ): raise labelbox.exceptions.AuthenticationError("Invalid API key") - authorization_error = check_errors(["AUTHORIZATION_ERROR"], - "extensions", "code") + authorization_error = check_errors( + ["AUTHORIZATION_ERROR"], "extensions", "code" + ) if authorization_error is not None: raise labelbox.exceptions.AuthorizationError( - authorization_error["message"]) + authorization_error["message"] + ) - validation_error = check_errors(["GRAPHQL_VALIDATION_FAILED"], - "extensions", "code") + validation_error = check_errors( + ["GRAPHQL_VALIDATION_FAILED"], "extensions", "code" + ) if validation_error is not None: message = validation_error["message"] @@ -290,11 +340,13 @@ def get_error_status_code(error: dict) -> int: else: raise labelbox.exceptions.InvalidQueryError(message) - graphql_error = check_errors(["GRAPHQL_PARSE_FAILED"], "extensions", - "code") + graphql_error = check_errors( + ["GRAPHQL_PARSE_FAILED"], "extensions", "code" + ) if graphql_error is not None: raise labelbox.exceptions.InvalidQueryError( - graphql_error["message"]) + graphql_error["message"] + ) # Check if API limit was exceeded response_msg = r_json.get("message", "") @@ -302,34 +354,41 @@ def get_error_status_code(error: dict) -> int: if response_msg.startswith("You have exceeded"): raise labelbox.exceptions.ApiLimitError(response_msg) - resource_not_found_error = check_errors(["RESOURCE_NOT_FOUND"], - "extensions", "code") + resource_not_found_error = check_errors( + ["RESOURCE_NOT_FOUND"], "extensions", "code" + ) if resource_not_found_error is not None: if raise_return_resource_not_found: raise labelbox.exceptions.ResourceNotFoundError( - message=resource_not_found_error["message"]) + message=resource_not_found_error["message"] + ) else: # Return None and let the caller methods raise an exception # as they already know which resource type and ID was requested return None - resource_conflict_error = check_errors(["RESOURCE_CONFLICT"], - "extensions", "code") + resource_conflict_error = check_errors( + ["RESOURCE_CONFLICT"], "extensions", "code" + ) if resource_conflict_error is not None: raise labelbox.exceptions.ResourceConflict( - resource_conflict_error["message"]) + resource_conflict_error["message"] + ) - malformed_request_error = check_errors(["MALFORMED_REQUEST"], - "extensions", "code") + malformed_request_error = check_errors( + ["MALFORMED_REQUEST"], "extensions", "code" + ) if malformed_request_error is not None: raise labelbox.exceptions.MalformedQueryException( - malformed_request_error[error_log_key]) + malformed_request_error[error_log_key] + ) # A lot of different error situations are now labeled serverside # as INTERNAL_SERVER_ERROR, when they are actually client errors. 
# TODO: fix this in the server API - internal_server_error = check_errors(["INTERNAL_SERVER_ERROR"], - "extensions", "code") + internal_server_error = check_errors( + ["INTERNAL_SERVER_ERROR"], "extensions", "code" + ) if internal_server_error is not None: message = internal_server_error.get("message") error_status_code = get_error_status_code(internal_server_error) @@ -344,8 +403,9 @@ def get_error_status_code(error: dict) -> int: else: raise labelbox.exceptions.InternalServerError(message) - not_allowed_error = check_errors(["OPERATION_NOT_ALLOWED"], - "extensions", "code") + not_allowed_error = check_errors( + ["OPERATION_NOT_ALLOWED"], "extensions", "code" + ) if not_allowed_error is not None: message = not_allowed_error.get("message") raise labelbox.exceptions.OperationNotAllowedException(message) @@ -356,10 +416,14 @@ def get_error_status_code(error: dict) -> int: map( lambda x: { "message": x["message"], - "code": x["extensions"]["code"] - }, errors)) - raise labelbox.exceptions.LabelboxError("Unknown error: %s" % - str(messages)) + "code": x["extensions"]["code"], + }, + errors, + ) + ) + raise labelbox.exceptions.LabelboxError( + "Unknown error: %s" % str(messages) + ) # if we do return a proper error code, and didn't catch this above # reraise @@ -368,7 +432,7 @@ def get_error_status_code(error: dict) -> int: # in the SDK if response.status_code != requests.codes.ok: message = f"{response.status_code} {response.reason}" - cause = r_json.get('message') + cause = r_json.get("message") raise labelbox.exceptions.LabelboxError(message, cause) return r_json["data"] @@ -388,18 +452,23 @@ def upload_file(self, path: str) -> str: content_type, _ = mimetypes.guess_type(path) filename = os.path.basename(path) with open(path, "rb") as f: - return self.upload_data(content=f.read(), - filename=filename, - content_type=content_type) - - @retry.Retry(predicate=retry.if_exception_type( - labelbox.exceptions.InternalServerError)) - def upload_data(self, - content: bytes, - filename: str = None, - content_type: str = None, - sign: bool = False) -> str: - """ Uploads the given data (bytes) to Labelbox. + return self.upload_data( + content=f.read(), filename=filename, content_type=content_type + ) + + @retry.Retry( + predicate=retry.if_exception_type( + labelbox.exceptions.InternalServerError + ) + ) + def upload_data( + self, + content: bytes, + filename: str = None, + content_type: str = None, + sign: bool = False, + ) -> str: + """Uploads the given data (bytes) to Labelbox. 
Args: content: bytestring to upload @@ -415,40 +484,43 @@ def upload_data(self, """ request_data = { - "operations": - json.dumps({ + "operations": json.dumps( + { "variables": { "file": None, "contentLength": len(content), - "sign": sign + "sign": sign, }, - "query": - """mutation UploadFile($file: Upload!, $contentLength: Int!, + "query": """mutation UploadFile($file: Upload!, $contentLength: Int!, $sign: Boolean) { uploadFile(file: $file, contentLength: $contentLength, sign: $sign) {url filename} } """, - }), + } + ), "map": (None, json.dumps({"1": ["variables.file"]})), } files = { - "1": (filename, content, content_type) if - (filename and content_type) else content + "1": (filename, content, content_type) + if (filename and content_type) + else content } headers = self._connection.headers.copy() headers.pop("Content-Type", None) - request = requests.Request('POST', - self.endpoint, - headers=headers, - data=request_data, - files=files) + request = requests.Request( + "POST", + self.endpoint, + headers=headers, + data=request_data, + files=files, + ) prepped: requests.PreparedRequest = request.prepare() response = self._connection.send(prepped) if response.status_code == 502: - error_502 = '502 Bad Gateway' + error_502 = "502 Bad Gateway" raise labelbox.exceptions.InternalServerError(error_502) elif response.status_code == 503: raise labelbox.exceptions.InternalServerError(response.text) @@ -459,22 +531,25 @@ def upload_data(self, file_data = response.json().get("data", None) except ValueError as e: # response is not valid JSON raise labelbox.exceptions.LabelboxError( - "Failed to upload, unknown cause", e) + "Failed to upload, unknown cause", e + ) if not file_data or not file_data.get("uploadFile", None): try: errors = response.json().get("errors", []) - error_msg = next(iter(errors), {}).get("message", - "Unknown error") + error_msg = next(iter(errors), {}).get( + "message", "Unknown error" + ) except Exception as e: error_msg = "Unknown error" raise labelbox.exceptions.LabelboxError( - "Failed to upload, message: %s" % error_msg) + "Failed to upload, message: %s" % error_msg + ) return file_data["uploadFile"]["url"] def _get_single(self, db_object_type, uid): - """ Fetches a single object of the given type, for the given ID. + """Fetches a single object of the given type, for the given ID. Args: db_object_type (type): DbObject subclass. @@ -491,12 +566,13 @@ def _get_single(self, db_object_type, uid): res = res and res.get(utils.camel_case(db_object_type.type_name())) if res is None: raise labelbox.exceptions.ResourceNotFoundError( - db_object_type, params) + db_object_type, params + ) else: return db_object_type(self, res) def get_project(self, project_id) -> Project: - """ Gets a single Project with the given ID. + """Gets a single Project with the given ID. >>> project = client.get_project("") @@ -511,7 +587,7 @@ def get_project(self, project_id) -> Project: return self._get_single(Entity.Project, project_id) def get_dataset(self, dataset_id) -> Dataset: - """ Gets a single Dataset with the given ID. + """Gets a single Dataset with the given ID. >>> dataset = client.get_dataset("") @@ -526,21 +602,21 @@ def get_dataset(self, dataset_id) -> Dataset: return self._get_single(Entity.Dataset, dataset_id) def get_user(self) -> User: - """ Gets the current User database object. + """Gets the current User database object. 
>>> user = client.get_user() """ return self._get_single(Entity.User, None) def get_organization(self) -> Organization: - """ Gets the Organization DB object of the current user. + """Gets the Organization DB object of the current user. >>> organization = client.get_organization() """ return self._get_single(Entity.Organization, None) def _get_all(self, db_object_type, where, filter_deleted=True): - """ Fetches all the objects of the given type the user has access to. + """Fetches all the objects of the given type the user has access to. Args: db_object_type (type): DbObject subclass. @@ -555,12 +631,15 @@ def _get_all(self, db_object_type, where, filter_deleted=True): query_str, params = query.get_all(db_object_type, where) return PaginatedCollection( - self, query_str, params, + self, + query_str, + params, [utils.camel_case(db_object_type.type_name()) + "s"], - db_object_type) + db_object_type, + ) def get_projects(self, where=None) -> PaginatedCollection: - """ Fetches all the projects the user has access to. + """Fetches all the projects the user has access to. >>> projects = client.get_projects(where=(Project.name == "") & (Project.description == "")) @@ -573,7 +652,7 @@ def get_projects(self, where=None) -> PaginatedCollection: return self._get_all(Entity.Project, where) def get_users(self, where=None) -> PaginatedCollection: - """ Fetches all the users. + """Fetches all the users. >>> users = client.get_users(where=User.email == "") @@ -586,7 +665,7 @@ def get_users(self, where=None) -> PaginatedCollection: return self._get_all(Entity.User, where, filter_deleted=False) def get_datasets(self, where=None) -> PaginatedCollection: - """ Fetches one or more datasets. + """Fetches one or more datasets. >>> datasets = client.get_datasets(where=(Dataset.name == "") & (Dataset.description == "")) @@ -599,7 +678,7 @@ def get_datasets(self, where=None) -> PaginatedCollection: return self._get_all(Entity.Dataset, where) def get_labeling_frontends(self, where=None) -> List[LabelingFrontend]: - """ Fetches all the labeling frontends. + """Fetches all the labeling frontends. >>> frontend = client.get_labeling_frontends(where=LabelingFrontend.name == "Editor") @@ -612,7 +691,7 @@ def get_labeling_frontends(self, where=None) -> List[LabelingFrontend]: return self._get_all(Entity.LabelingFrontend, where) def _create(self, db_object_type, data, extra_params={}): - """ Creates an object on the server. Attribute values are + """Creates an object on the server. Attribute values are passed as keyword arguments: Args: @@ -630,8 +709,9 @@ def _create(self, db_object_type, data, extra_params={}): # Convert string attribute names to Field or Relationship objects. # Also convert Labelbox object values to their UIDs. 
data = { - db_object_type.attribute(attr) if isinstance(attr, str) else attr: - value.uid if isinstance(value, DbObject) else value + db_object_type.attribute(attr) + if isinstance(attr, str) + else attr: value.uid if isinstance(value, DbObject) else value for attr, value in data.items() } @@ -640,15 +720,17 @@ def _create(self, db_object_type, data, extra_params={}): res = self.execute(query_string, params) if not res: - raise labelbox.exceptions.LabelboxError("Failed to create %s" % - db_object_type.type_name()) + raise labelbox.exceptions.LabelboxError( + "Failed to create %s" % db_object_type.type_name() + ) res = res["create%s" % db_object_type.type_name()] return db_object_type(self, res) - def create_model_config(self, name: str, model_id: str, - inference_params: dict) -> ModelConfig: - """ Creates a new model config with the given params. + def create_model_config( + self, name: str, model_id: str, inference_params: dict + ) -> ModelConfig: + """Creates a new model config with the given params. Model configs are scoped to organizations, and can be reused between projects. Args: @@ -673,13 +755,13 @@ def create_model_config(self, name: str, model_id: str, params = { "modelId": model_id, "inferenceParams": inference_params, - "name": name + "name": name, } result = self.execute(query, params) - return ModelConfig(self, result['createModelConfig']) + return ModelConfig(self, result["createModelConfig"]) def delete_model_config(self, id: str) -> bool: - """ Deletes an existing model config with the given id + """Deletes an existing model config with the given id Args: id (str): ID of existing model config @@ -697,13 +779,14 @@ def delete_model_config(self, id: str) -> bool: result = self.execute(query, params) if not result: raise labelbox.exceptions.ResourceNotFoundError( - Entity.ModelConfig, params) - return result['deleteModelConfig']['success'] + Entity.ModelConfig, params + ) + return result["deleteModelConfig"]["success"] - def create_dataset(self, - iam_integration=IAMIntegration._DEFAULT, - **kwargs) -> Dataset: - """ Creates a Dataset object on the server. + def create_dataset( + self, iam_integration=IAMIntegration._DEFAULT, **kwargs + ) -> Dataset: + """Creates a Dataset object on the server. Attribute values are passed as keyword arguments. @@ -724,8 +807,9 @@ def create_dataset(self, """ dataset = self._create(Entity.Dataset, kwargs) if iam_integration == IAMIntegration._DEFAULT: - iam_integration = self.get_organization( - ).get_default_iam_integration() + iam_integration = ( + self.get_organization().get_default_iam_integration() + ) if iam_integration is None: return dataset @@ -738,21 +822,23 @@ def create_dataset(self, if not iam_integration.valid: raise ValueError( - "Integration is not valid. Please select another.") + "Integration is not valid. Please select another." + ) self.execute( """mutation setSignerForDatasetPyApi($signerId: ID!, $datasetId: ID!) 
{ setSignerForDataset(data: { signerId: $signerId}, where: {id: $datasetId}){id}} - """, { - 'signerId': iam_integration.uid, - 'datasetId': dataset.uid - }) + """, + {"signerId": iam_integration.uid, "datasetId": dataset.uid}, + ) validation_result = self.execute( """mutation validateDatasetPyApi($id: ID!){validateDataset(where: {id : $id}){ valid checks{name, success}}} - """, {'id': dataset.uid}) + """, + {"id": dataset.uid}, + ) - if not validation_result['validateDataset']['valid']: + if not validation_result["validateDataset"]["valid"]: raise labelbox.exceptions.LabelboxError( f"IAMIntegration was not successfully added to the dataset." ) @@ -762,7 +848,7 @@ def create_dataset(self, return dataset def create_project(self, **kwargs) -> Project: - """ Creates a Project object on the server. + """Creates a Project object on the server. Attribute values are passed as keyword arguments. @@ -800,26 +886,32 @@ def create_project(self, **kwargs) -> Project: return self._create_project(**kwargs) @overload - def create_model_evaluation_project(self, - dataset_name: str, - dataset_id: str = None, - data_row_count: int = 100, - **kwargs) -> Project: + def create_model_evaluation_project( + self, + dataset_name: str, + dataset_id: str = None, + data_row_count: int = 100, + **kwargs, + ) -> Project: pass @overload - def create_model_evaluation_project(self, - dataset_id: str, - dataset_name: str = None, - data_row_count: int = 100, - **kwargs) -> Project: + def create_model_evaluation_project( + self, + dataset_id: str, + dataset_name: str = None, + data_row_count: int = 100, + **kwargs, + ) -> Project: pass - def create_model_evaluation_project(self, - dataset_id: Optional[str] = None, - dataset_name: Optional[str] = None, - data_row_count: int = 100, - **kwargs) -> Project: + def create_model_evaluation_project( + self, + dataset_id: Optional[str] = None, + dataset_name: Optional[str] = None, + data_row_count: int = 100, + **kwargs, + ) -> Project: """ Use this method exclusively to create a chat model evaluation project. Args: @@ -875,10 +967,12 @@ def create_offline_model_evaluation_project(self, **kwargs) -> Project: Returns: Project: The created project """ - kwargs[ - "media_type"] = MediaType.Conversational # Only Conversational is supported - kwargs[ - "editor_task_type"] = EditorTaskType.OfflineModelChatEvaluation.value # Special editor task type for offline model evaluation + kwargs["media_type"] = ( + MediaType.Conversational + ) # Only Conversational is supported + kwargs["editor_task_type"] = ( + EditorTaskType.OfflineModelChatEvaluation.value + ) # Special editor task type for offline model evaluation # The following arguments are not supported for offline model evaluation kwargs.pop("dataset_name_or_id", None) @@ -888,11 +982,12 @@ def create_offline_model_evaluation_project(self, **kwargs) -> Project: return self._create_project(**kwargs) def create_prompt_response_generation_project( - self, - dataset_id: Optional[str] = None, - dataset_name: Optional[str] = None, - data_row_count: int = 100, - **kwargs) -> Project: + self, + dataset_id: Optional[str] = None, + dataset_name: Optional[str] = None, + data_row_count: int = 100, + **kwargs, + ) -> Project: """ Use this method exclusively to create a prompt and response generation project. @@ -927,7 +1022,8 @@ def create_prompt_response_generation_project( if dataset_id and dataset_name: raise ValueError( - "Only provide a dataset_name or dataset_id, not both.") + "Only provide a dataset_name or dataset_id, not both." 
+ ) if data_row_count <= 0: raise ValueError("data_row_count must be a positive integer.") @@ -940,7 +1036,8 @@ def create_prompt_response_generation_project( dataset_name_or_id = dataset_name if "media_type" in kwargs and kwargs.get("media_type") not in [ - MediaType.LLMPromptCreation, MediaType.LLMPromptResponseCreation + MediaType.LLMPromptCreation, + MediaType.LLMPromptResponseCreation, ]: raise ValueError( "media_type must be either LLMPromptCreation or LLMPromptResponseCreation" @@ -963,8 +1060,9 @@ def create_response_creation_project(self, **kwargs) -> Project: Project: The created project """ kwargs["media_type"] = MediaType.Text # Only Text is supported - kwargs[ - "editor_task_type"] = EditorTaskType.ResponseCreation.value # Special editor task type for response creation projects + kwargs["editor_task_type"] = ( + EditorTaskType.ResponseCreation.value + ) # Special editor task type for response creation projects # The following arguments are not supported for response creation projects kwargs.pop("dataset_name_or_id", None) @@ -976,7 +1074,10 @@ def create_response_creation_project(self, **kwargs) -> Project: def _create_project(self, **kwargs) -> Project: auto_audit_percentage = kwargs.get("auto_audit_percentage") auto_audit_number_of_labels = kwargs.get("auto_audit_number_of_labels") - if auto_audit_percentage is not None or auto_audit_number_of_labels is not None: + if ( + auto_audit_percentage is not None + or auto_audit_number_of_labels is not None + ): raise ValueError( "quality_modes must be set instead of auto_audit_percentage or auto_audit_number_of_labels." ) @@ -999,13 +1100,16 @@ def _create_project(self, **kwargs) -> Project: if media_type and MediaType.is_supported(media_type): media_type_value = media_type.value elif media_type: - raise TypeError(f"{media_type} is not a valid media type. Use" - f" any of {MediaType.get_supported_members()}" - " from MediaType. Example: MediaType.Image.") + raise TypeError( + f"{media_type} is not a valid media type. Use" + f" any of {MediaType.get_supported_members()}" + " from MediaType. Example: MediaType.Image." + ) else: logger.warning( "Creating a project without specifying media_type" - " through this method will soon no longer be supported.") + " through this method will soon no longer be supported." 
+ ) media_type_value = None quality_modes = kwargs.get("quality_modes") @@ -1034,22 +1138,28 @@ def _create_project(self, **kwargs) -> Project: if quality_mode: quality_modes_set = {quality_mode} - if (quality_modes_set is None or len(quality_modes_set) == 0 or - quality_modes_set - == {QualityMode.Benchmark, QualityMode.Consensus}): - data[ - "auto_audit_number_of_labels"] = CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS + if ( + quality_modes_set is None + or len(quality_modes_set) == 0 + or quality_modes_set + == {QualityMode.Benchmark, QualityMode.Consensus} + ): + data["auto_audit_number_of_labels"] = ( + CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS + ) data["auto_audit_percentage"] = CONSENSUS_AUTO_AUDIT_PERCENTAGE data["is_benchmark_enabled"] = True data["is_consensus_enabled"] = True elif quality_modes_set == {QualityMode.Benchmark}: - data[ - "auto_audit_number_of_labels"] = BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS + data["auto_audit_number_of_labels"] = ( + BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS + ) data["auto_audit_percentage"] = BENCHMARK_AUTO_AUDIT_PERCENTAGE data["is_benchmark_enabled"] = True elif quality_modes_set == {QualityMode.Consensus}: - data[ - "auto_audit_number_of_labels"] = CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS + data["auto_audit_number_of_labels"] = ( + CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS + ) data["auto_audit_percentage"] = CONSENSUS_AUTO_AUDIT_PERCENTAGE data["is_consensus_enabled"] = True else: @@ -1062,10 +1172,12 @@ def _create_project(self, **kwargs) -> Project: params["media_type"] = media_type_value extra_params = { - Field.String("dataset_name_or_id"): - params.pop("dataset_name_or_id", None), - Field.Boolean("append_to_existing_dataset"): - params.pop("append_to_existing_dataset", None), + Field.String("dataset_name_or_id"): params.pop( + "dataset_name_or_id", None + ), + Field.Boolean("append_to_existing_dataset"): params.pop( + "append_to_existing_dataset", None + ), } extra_params = {k: v for k, v in extra_params.items() if v is not None} return self._create(Entity.Project, params, extra_params) @@ -1089,13 +1201,14 @@ def get_data_row(self, data_row_id): def get_data_row_by_global_key(self, global_key: str) -> DataRow: """ - Returns: DataRow: returns a single data row given the global key + Returns: DataRow: returns a single data row given the global key """ res = self.get_data_row_ids_for_global_keys([global_key]) - if res['status'] != "SUCCESS": + if res["status"] != "SUCCESS": raise labelbox.exceptions.ResourceNotFoundError( - Entity.DataRow, {global_key: global_key}) - data_row_id = res['results'][0] + Entity.DataRow, {global_key: global_key} + ) + data_row_id = res["results"][0] return self.get_data_row(data_row_id) @@ -1111,7 +1224,7 @@ def get_data_row_metadata_ontology(self) -> DataRowMetadataOntology: return self._data_row_metadata_ontology def get_model(self, model_id) -> Model: - """ Gets a single Model with the given ID. + """Gets a single Model with the given ID. >>> model = client.get_model("") @@ -1126,7 +1239,7 @@ def get_model(self, model_id) -> Model: return self._get_single(Entity.Model, model_id) def get_models(self, where=None) -> List[Model]: - """ Fetches all the models the user has access to. + """Fetches all the models the user has access to. >>> models = client.get_models(where=(Model.name == "")) @@ -1139,7 +1252,7 @@ def get_models(self, where=None) -> List[Model]: return self._get_all(Entity.Model, where, filter_deleted=False) def create_model(self, name, ontology_id) -> Model: - """ Creates a Model object on the server. 
+ """Creates a Model object on the server. >>> model = client.create_model(, ) @@ -1158,14 +1271,14 @@ def create_model(self, name, ontology_id) -> Model: } }""" % query.results_query_part(Entity.Model) - result = self.execute(query_str, { - "name": name, - "ontologyId": ontology_id - }) - return Entity.Model(self, result['createModel']) + result = self.execute( + query_str, {"name": name, "ontologyId": ontology_id} + ) + return Entity.Model(self, result["createModel"]) def get_data_row_ids_for_external_ids( - self, external_ids: List[str]) -> Dict[str, List[str]]: + self, external_ids: List[str] + ) -> Dict[str, List[str]]: """ Returns a list of data row ids for a list of external ids. There is a max of 1500 items returned at a time. @@ -1183,10 +1296,10 @@ def get_data_row_ids_for_external_ids( result = defaultdict(list) for i in range(0, len(external_ids), max_ids_per_request): for row in self.execute( - query_str, - {'externalId_in': external_ids[i:i + max_ids_per_request] - })['externalIdsToDataRowIds']: - result[row['externalId']].append(row['dataRowId']) + query_str, + {"externalId_in": external_ids[i : i + max_ids_per_request]}, + )["externalIdsToDataRowIds"]: + result[row["externalId"]].append(row["dataRowId"]) return result def get_ontology(self, ontology_id) -> Ontology: @@ -1216,10 +1329,15 @@ def get_ontologies(self, name_contains) -> PaginatedCollection: } } """ % query.results_query_part(Entity.Ontology) - params = {'search': name_contains, 'filter': {'status': 'ALL'}} - return PaginatedCollection(self, query_str, params, - ['ontologies', 'nodes'], Entity.Ontology, - ['ontologies', 'nextCursor']) + params = {"search": name_contains, "filter": {"status": "ALL"}} + return PaginatedCollection( + self, + query_str, + params, + ["ontologies", "nodes"], + Entity.Ontology, + ["ontologies", "nextCursor"], + ) def get_feature_schema(self, feature_schema_id): """ @@ -1237,10 +1355,9 @@ def get_feature_schema(self, feature_schema_id): res = self.execute( query_str, - {'rootSchemaNodeWhere': { - 'featureSchemaId': feature_schema_id - }})['rootSchemaNode'] - res['id'] = res['normalized']['featureSchemaId'] + {"rootSchemaNodeWhere": {"featureSchemaId": feature_schema_id}}, + )["rootSchemaNode"] + res["id"] = res["normalized"]["featureSchemaId"] return Entity.FeatureSchema(self, res) def get_feature_schemas(self, name_contains) -> PaginatedCollection: @@ -1261,25 +1378,30 @@ def get_feature_schemas(self, name_contains) -> PaginatedCollection: } } """ % query.results_query_part(Entity.FeatureSchema) - params = {'search': name_contains, 'filter': {'status': 'ALL'}} + params = {"search": name_contains, "filter": {"status": "ALL"}} def rootSchemaPayloadToFeatureSchema(client, payload): # Technically we are querying for a Schema Node. 
# But the features are the same so we just grab the feature schema id - payload['id'] = payload['normalized']['featureSchemaId'] + payload["id"] = payload["normalized"]["featureSchemaId"] return Entity.FeatureSchema(client, payload) - return PaginatedCollection(self, query_str, params, - ['rootSchemaNodes', 'nodes'], - rootSchemaPayloadToFeatureSchema, - ['rootSchemaNodes', 'nextCursor']) + return PaginatedCollection( + self, + query_str, + params, + ["rootSchemaNodes", "nodes"], + rootSchemaPayloadToFeatureSchema, + ["rootSchemaNodes", "nextCursor"], + ) def create_ontology_from_feature_schemas( - self, - name, - feature_schema_ids, - media_type: MediaType = None, - ontology_kind: OntologyKind = None) -> Ontology: + self, + name, + feature_schema_ids, + media_type: MediaType = None, + ontology_kind: OntologyKind = None, + ) -> Ontology: """ Creates an ontology from a list of feature schema ids @@ -1298,22 +1420,27 @@ def create_ontology_from_feature_schemas( tools, classifications = [], [] for feature_schema_id in feature_schema_ids: feature_schema = self.get_feature_schema(feature_schema_id) - tool = ['tool'] - if 'tool' in feature_schema.normalized: - tool = feature_schema.normalized['tool'] + tool = ["tool"] + if "tool" in feature_schema.normalized: + tool = feature_schema.normalized["tool"] try: Tool.Type(tool) tools.append(feature_schema.normalized) except ValueError: raise ValueError( - f"Tool `{tool}` not in list of supported tools.") - elif 'type' in feature_schema.normalized: - classification = feature_schema.normalized['type'] - if classification in Classification.Type._value2member_map_.keys( + f"Tool `{tool}` not in list of supported tools." + ) + elif "type" in feature_schema.normalized: + classification = feature_schema.normalized["type"] + if ( + classification + in Classification.Type._value2member_map_.keys() ): Classification.Type(classification) classifications.append(feature_schema.normalized) - elif classification in PromptResponseClassification.Type._value2member_map_.keys( + elif ( + classification + in PromptResponseClassification.Type._value2member_map_.keys() ): PromptResponseClassification.Type(classification) classifications.append(feature_schema.normalized) @@ -1325,13 +1452,15 @@ def create_ontology_from_feature_schemas( raise ValueError( "Neither `tool` or `classification` found in the normalized feature schema" ) - normalized = {'tools': tools, 'classifications': classifications} + normalized = {"tools": tools, "classifications": classifications} # validation for ontology_kind and media_type is done within self.create_ontology - return self.create_ontology(name=name, - normalized=normalized, - media_type=media_type, - ontology_kind=ontology_kind) + return self.create_ontology( + name=name, + normalized=normalized, + media_type=media_type, + ontology_kind=ontology_kind, + ) def delete_unused_feature_schema(self, feature_schema_id: str) -> None: """ @@ -1342,14 +1471,18 @@ def delete_unused_feature_schema(self, feature_schema_id: str) -> None: >>> client.delete_unused_feature_schema("cleabc1my012ioqvu5anyaabc") """ - endpoint = self.rest_endpoint + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) + endpoint = ( + self.rest_endpoint + + "/feature-schemas/" + + urllib.parse.quote(feature_schema_id) + ) response = self._connection.delete(endpoint) if response.status_code != requests.codes.no_content: raise labelbox.exceptions.LabelboxError( - "Failed to delete the feature schema, message: " + - str(response.json()['message'])) + "Failed to delete 
the feature schema, message: " + + str(response.json()["message"]) + ) def delete_unused_ontology(self, ontology_id: str) -> None: """ @@ -1359,17 +1492,22 @@ def delete_unused_ontology(self, ontology_id: str) -> None: Example: >>> client.delete_unused_ontology("cleabc1my012ioqvu5anyaabc") """ - endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote( - ontology_id) + endpoint = ( + self.rest_endpoint + + "/ontologies/" + + urllib.parse.quote(ontology_id) + ) response = self._connection.delete(endpoint) if response.status_code != requests.codes.no_content: raise labelbox.exceptions.LabelboxError( - "Failed to delete the ontology, message: " + - str(response.json()['message'])) + "Failed to delete the ontology, message: " + + str(response.json()["message"]) + ) - def update_feature_schema_title(self, feature_schema_id: str, - title: str) -> FeatureSchema: + def update_feature_schema_title( + self, feature_schema_id: str, title: str + ) -> FeatureSchema: """ Updates a title of a feature schema Args: @@ -1381,16 +1519,21 @@ def update_feature_schema_title(self, feature_schema_id: str, >>> client.update_feature_schema_title("cleabc1my012ioqvu5anyaabc", "New Title") """ - endpoint = self.rest_endpoint + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) + '/definition' + endpoint = ( + self.rest_endpoint + + "/feature-schemas/" + + urllib.parse.quote(feature_schema_id) + + "/definition" + ) response = self._connection.patch(endpoint, json={"title": title}) if response.status_code == requests.codes.ok: return self.get_feature_schema(feature_schema_id) else: raise labelbox.exceptions.LabelboxError( - "Failed to update the feature schema, message: " + - str(response.json()['message'])) + "Failed to update the feature schema, message: " + + str(response.json()["message"]) + ) def upsert_feature_schema(self, feature_schema: Dict) -> FeatureSchema: """ @@ -1408,23 +1551,29 @@ def upsert_feature_schema(self, feature_schema: Dict) -> FeatureSchema: >>> client.upsert_feature_schema(tool.asdict()) """ - feature_schema_id = feature_schema.get( - "featureSchemaId") or "new_feature_schema_id" - endpoint = self.rest_endpoint + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) + feature_schema_id = ( + feature_schema.get("featureSchemaId") or "new_feature_schema_id" + ) + endpoint = ( + self.rest_endpoint + + "/feature-schemas/" + + urllib.parse.quote(feature_schema_id) + ) response = self._connection.put( - endpoint, json={"normalized": json.dumps(feature_schema)}) + endpoint, json={"normalized": json.dumps(feature_schema)} + ) if response.status_code == requests.codes.ok: - return self.get_feature_schema(response.json()['schemaId']) + return self.get_feature_schema(response.json()["schemaId"]) else: raise labelbox.exceptions.LabelboxError( - "Failed to upsert the feature schema, message: " + - str(response.json()['message'])) + "Failed to upsert the feature schema, message: " + + str(response.json()["message"]) + ) - def insert_feature_schema_into_ontology(self, feature_schema_id: str, - ontology_id: str, - position: int) -> None: + def insert_feature_schema_into_ontology( + self, feature_schema_id: str, ontology_id: str, position: int + ) -> None: """ Inserts a feature schema into an ontology. If the feature schema is already in the ontology, it will be moved to the new position. 
@@ -1436,14 +1585,19 @@ def insert_feature_schema_into_ontology(self, feature_schema_id: str, >>> client.insert_feature_schema_into_ontology("cleabc1my012ioqvu5anyaabc", "clefdvwl7abcgefgu3lyvcde", 2) """ - endpoint = self.rest_endpoint + '/ontologies/' + urllib.parse.quote( - ontology_id) + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) + endpoint = ( + self.rest_endpoint + + "/ontologies/" + + urllib.parse.quote(ontology_id) + + "/feature-schemas/" + + urllib.parse.quote(feature_schema_id) + ) response = self._connection.post(endpoint, json={"position": position}) if response.status_code != requests.codes.created: raise labelbox.exceptions.LabelboxError( "Failed to insert the feature schema into the ontology, message: " - + str(response.json()['message'])) + + str(response.json()["message"]) + ) def get_unused_ontologies(self, after: str = None) -> List[str]: """ @@ -1466,8 +1620,9 @@ def get_unused_ontologies(self, after: str = None) -> List[str]: return response.json() else: raise labelbox.exceptions.LabelboxError( - "Failed to get unused ontologies, message: " + - str(response.json()['message'])) + "Failed to get unused ontologies, message: " + + str(response.json()["message"]) + ) def get_unused_feature_schemas(self, after: str = None) -> List[str]: """ @@ -1490,14 +1645,17 @@ def get_unused_feature_schemas(self, after: str = None) -> List[str]: return response.json() else: raise labelbox.exceptions.LabelboxError( - "Failed to get unused feature schemas, message: " + - str(response.json()['message'])) + "Failed to get unused feature schemas, message: " + + str(response.json()["message"]) + ) - def create_ontology(self, - name, - normalized, - media_type: MediaType = None, - ontology_kind: OntologyKind = None) -> Ontology: + def create_ontology( + self, + name, + normalized, + media_type: MediaType = None, + ontology_kind: OntologyKind = None, + ) -> Ontology: """ Creates an ontology from normalized data >>> normalized = {"tools" : [{'tool': 'polygon', 'name': 'cat', 'color': 'black'}], "classifications" : []} @@ -1515,7 +1673,7 @@ def create_ontology(self, name (str): Name of the ontology normalized (dict): A normalized ontology payload. See above for details. media_type (MediaType or None): Media type of a new ontology - ontology_kind (OntologyKind or None): set to OntologyKind.ModelEvaluation if the ontology is for chat evaluation or + ontology_kind (OntologyKind or None): set to OntologyKind.ModelEvaluation if the ontology is for chat evaluation or OntologyKind.ResponseCreation if ontology is for response creation, leave as None otherwise. 
Returns: @@ -1533,9 +1691,11 @@ def create_ontology(self, if ontology_kind and OntologyKind.is_supported(ontology_kind): media_type = OntologyKind.evaluate_ontology_kind_with_media_type( - ontology_kind, media_type) + ontology_kind, media_type + ) editor_task_type_value = EditorTaskTypeMapper.to_editor_task_type( - ontology_kind, media_type).value + ontology_kind, media_type + ).value elif ontology_kind: raise OntologyKind.get_ontology_kind_validation_error(ontology_kind) else: @@ -1545,17 +1705,17 @@ def create_ontology(self, upsertOntology(data: $data){ %s } } """ % query.results_query_part(Entity.Ontology) params = { - 'data': { - 'name': name, - 'normalized': json.dumps(normalized), - 'mediaType': media_type_value + "data": { + "name": name, + "normalized": json.dumps(normalized), + "mediaType": media_type_value, } } if editor_task_type_value: - params['data']['editorTaskType'] = editor_task_type_value + params["data"]["editorTaskType"] = editor_task_type_value res = self.execute(query_str, params) - return Entity.Ontology(self, res['upsertOntology']) + return Entity.Ontology(self, res["upsertOntology"]) def create_feature_schema(self, normalized): """ @@ -1592,15 +1752,15 @@ def create_feature_schema(self, normalized): upsertRootSchemaNode(data: $data){ %s } } """ % query.results_query_part(Entity.FeatureSchema) normalized = {k: v for k, v in normalized.items() if v} - params = {'data': {'normalized': json.dumps(normalized)}} - res = self.execute(query_str, params)['upsertRootSchemaNode'] + params = {"data": {"normalized": json.dumps(normalized)}} + res = self.execute(query_str, params)["upsertRootSchemaNode"] # Technically we are querying for a Schema Node. # But the features are the same so we just grab the feature schema id - res['id'] = res['normalized']['featureSchemaId'] + res["id"] = res["normalized"]["featureSchemaId"] return Entity.FeatureSchema(self, res) def get_model_run(self, model_run_id: str) -> ModelRun: - """ Gets a single ModelRun with the given ID. + """Gets a single ModelRun with the given ID. >>> model_run = client.get_model_run("") @@ -1612,9 +1772,10 @@ def get_model_run(self, model_run_id: str) -> ModelRun: return self._get_single(Entity.ModelRun, model_run_id) def assign_global_keys_to_data_rows( - self, - global_key_to_data_row_inputs: List[Dict[str, str]], - timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]: + self, + global_key_to_data_row_inputs: List[Dict[str, str]], + timeout_seconds=60, + ) -> Dict[str, Union[str, List[Any]]]: """ Assigns global keys to data rows. 
@@ -1645,21 +1806,29 @@ def assign_global_keys_to_data_rows( [{'data_row_id': 'cl7tpjzw30031ka6g4evqdfoy', 'global_key': 'gk"', 'error': 'Invalid global key'}] """ - def _format_successful_rows(rows: Dict[str, str], - sanitized: bool) -> List[Dict[str, str]]: - return [{ - 'data_row_id': r['dataRowId'], - 'global_key': r['globalKey'], - 'sanitized': sanitized - } for r in rows] + def _format_successful_rows( + rows: Dict[str, str], sanitized: bool + ) -> List[Dict[str, str]]: + return [ + { + "data_row_id": r["dataRowId"], + "global_key": r["globalKey"], + "sanitized": sanitized, + } + for r in rows + ] - def _format_failed_rows(rows: Dict[str, str], - error_msg: str) -> List[Dict[str, str]]: - return [{ - 'data_row_id': r['dataRowId'], - 'global_key': r['globalKey'], - 'error': error_msg - } for r in rows] + def _format_failed_rows( + rows: Dict[str, str], error_msg: str + ) -> List[Dict[str, str]]: + return [ + { + "data_row_id": r["dataRowId"], + "global_key": r["globalKey"], + "error": error_msg, + } + for r in rows + ] # Validate input dict validation_errors = [] @@ -1679,9 +1848,10 @@ def _format_failed_rows(rows: Dict[str, str], } """ params = { - 'globalKeyDataRowLinks': [{ - utils.camel_case(key): value for key, value in input.items() - } for input in global_key_to_data_row_inputs] + "globalKeyDataRowLinks": [ + {utils.camel_case(key): value for key, value in input.items()} + for input in global_key_to_data_row_inputs + ] } assign_global_keys_to_data_rows_job = self.execute(query_str, params) @@ -1709,9 +1879,9 @@ def _format_failed_rows(rows: Dict[str, str], }}} """ result_params = { - "jobId": - assign_global_keys_to_data_rows_job["assignGlobalKeysToDataRows" - ]["jobId"] + "jobId": assign_global_keys_to_data_rows_job[ + "assignGlobalKeysToDataRows" + ]["jobId"] } # Poll job status until finished, then retrieve results @@ -1719,27 +1889,36 @@ def _format_failed_rows(rows: Dict[str, str], start_time = time.time() while True: res = self.execute(result_query_str, result_params) - if res["assignGlobalKeysToDataRowsResult"][ - "jobStatus"] == "COMPLETE": + if ( + res["assignGlobalKeysToDataRowsResult"]["jobStatus"] + == "COMPLETE" + ): results, errors = [], [] - res = res['assignGlobalKeysToDataRowsResult']['data'] + res = res["assignGlobalKeysToDataRowsResult"]["data"] # Successful assignments results.extend( - _format_successful_rows(rows=res['sanitizedAssignments'], - sanitized=True)) + _format_successful_rows( + rows=res["sanitizedAssignments"], sanitized=True + ) + ) results.extend( - _format_successful_rows(rows=res['unmodifiedAssignments'], - sanitized=False)) + _format_successful_rows( + rows=res["unmodifiedAssignments"], sanitized=False + ) + ) # Failed assignments errors.extend( _format_failed_rows( - rows=res['invalidGlobalKeyAssignments'], - error_msg= - "Invalid assignment. Either DataRow does not exist, or globalKey is invalid" - )) + rows=res["invalidGlobalKeyAssignments"], + error_msg="Invalid assignment. 
Either DataRow does not exist, or globalKey is invalid", + ) + ) errors.extend( - _format_failed_rows(rows=res['accessDeniedAssignments'], - error_msg="Access denied to Data Row")) + _format_failed_rows( + rows=res["accessDeniedAssignments"], + error_msg="Access denied to Data Row", + ) + ) if not errors: status = CollectionJobStatus.SUCCESS.value @@ -1758,10 +1937,12 @@ def _format_failed_rows(rows: Dict[str, str], "results": results, "errors": errors, } - elif res["assignGlobalKeysToDataRowsResult"][ - "jobStatus"] == "FAILED": + elif ( + res["assignGlobalKeysToDataRowsResult"]["jobStatus"] == "FAILED" + ): raise labelbox.exceptions.LabelboxError( - "Job assign_global_keys_to_data_rows failed.") + "Job assign_global_keys_to_data_rows failed." + ) current_time = time.time() if current_time - start_time > timeout_seconds: raise labelbox.exceptions.TimeoutError( @@ -1770,9 +1951,8 @@ def _format_failed_rows(rows: Dict[str, str], time.sleep(sleep_time) def get_data_row_ids_for_global_keys( - self, - global_keys: List[str], - timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]: + self, global_keys: List[str], timeout_seconds=60 + ) -> Dict[str, Union[str, List[Any]]]: """ Gets data row ids for a list of global keys. @@ -1805,9 +1985,10 @@ def get_data_row_ids_for_global_keys( [{'global_key': 'asdf', 'error': 'Data Row not found'}] """ - def _format_failed_rows(rows: List[str], - error_msg: str) -> List[Dict[str, str]]: - return [{'global_key': r, 'error': error_msg} for r in rows] + def _format_failed_rows( + rows: List[str], error_msg: str + ) -> List[Dict[str, str]]: + return [{"global_key": r, "error": error_msg} for r in rows] # Start get data rows for global keys job query_str = """query getDataRowsForGlobalKeysPyApi($globalKeys: [ID!]!) { @@ -1825,8 +2006,9 @@ def _format_failed_rows(rows: List[str], } jobStatus}} """ result_params = { - "jobId": - data_rows_for_global_keys_job["dataRowsForGlobalKeys"]["jobId"] + "jobId": data_rows_for_global_keys_job["dataRowsForGlobalKeys"][ + "jobId" + ] } # Poll job status until finished, then retrieve results @@ -1834,20 +2016,25 @@ def _format_failed_rows(rows: List[str], start_time = time.time() while True: res = self.execute(result_query_str, result_params) - if res["dataRowsForGlobalKeysResult"]['jobStatus'] == "COMPLETE": - data = res["dataRowsForGlobalKeysResult"]['data'] + if res["dataRowsForGlobalKeysResult"]["jobStatus"] == "COMPLETE": + data = res["dataRowsForGlobalKeysResult"]["data"] results, errors = [], [] - results.extend([row['id'] for row in data['fetchedDataRows']]) + results.extend([row["id"] for row in data["fetchedDataRows"]]) errors.extend( - _format_failed_rows(data['notFoundGlobalKeys'], - "Data Row not found")) + _format_failed_rows( + data["notFoundGlobalKeys"], "Data Row not found" + ) + ) errors.extend( - _format_failed_rows(data['accessDeniedGlobalKeys'], - "Access denied to Data Row")) + _format_failed_rows( + data["accessDeniedGlobalKeys"], + "Access denied to Data Row", + ) + ) # Invalid results may contain empty string, so we must filter # them prior to checking for PARTIAL_SUCCESS - filtered_results = list(filter(lambda r: r != '', results)) + filtered_results = list(filter(lambda r: r != "", results)) if not errors: status = CollectionJobStatus.SUCCESS.value elif errors and len(filtered_results) > 0: @@ -1861,9 +2048,10 @@ def _format_failed_rows(rows: List[str], ) return {"status": status, "results": results, "errors": errors} - elif res["dataRowsForGlobalKeysResult"]['jobStatus'] == "FAILED": + elif 
res["dataRowsForGlobalKeysResult"]["jobStatus"] == "FAILED": raise labelbox.exceptions.LabelboxError( - "Job dataRowsForGlobalKeys failed.") + "Job dataRowsForGlobalKeys failed." + ) current_time = time.time() if current_time - start_time > timeout_seconds: raise labelbox.exceptions.TimeoutError( @@ -1872,9 +2060,8 @@ def _format_failed_rows(rows: List[str], time.sleep(sleep_time) def clear_global_keys( - self, - global_keys: List[str], - timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]: + self, global_keys: List[str], timeout_seconds=60 + ) -> Dict[str, Union[str, List[Any]]]: """ Clears global keys for the data rows tha correspond to the global keys provided. @@ -1900,9 +2087,10 @@ def clear_global_keys( [{'global_key': 'notfoundkey', 'error': 'Failed to find data row matching provided global key'}] """ - def _format_failed_rows(rows: List[str], - error_msg: str) -> List[Dict[str, str]]: - return [{'global_key': r, 'error': error_msg} for r in rows] + def _format_failed_rows( + rows: List[str], error_msg: str + ) -> List[Dict[str, str]]: + return [{"global_key": r, "error": error_msg} for r in rows] # Start get data rows for global keys job query_str = """mutation clearGlobalKeysPyApi($globalKeys: [ID!]!) { @@ -1928,22 +2116,28 @@ def _format_failed_rows(rows: List[str], start_time = time.time() while True: res = self.execute(result_query_str, result_params) - if res["clearGlobalKeysResult"]['jobStatus'] == "COMPLETE": - data = res["clearGlobalKeysResult"]['data'] + if res["clearGlobalKeysResult"]["jobStatus"] == "COMPLETE": + data = res["clearGlobalKeysResult"]["data"] results, errors = [], [] - results.extend(data['clearedGlobalKeys']) + results.extend(data["clearedGlobalKeys"]) errors.extend( - _format_failed_rows(data['failedToClearGlobalKeys'], - "Clearing global key failed")) + _format_failed_rows( + data["failedToClearGlobalKeys"], + "Clearing global key failed", + ) + ) errors.extend( _format_failed_rows( - data['notFoundGlobalKeys'], - "Failed to find data row matching provided global key")) + data["notFoundGlobalKeys"], + "Failed to find data row matching provided global key", + ) + ) errors.extend( _format_failed_rows( - data['accessDeniedGlobalKeys'], - "Denied access to modify data row matching provided global key" - )) + data["accessDeniedGlobalKeys"], + "Denied access to modify data row matching provided global key", + ) + ) if not errors: status = CollectionJobStatus.SUCCESS.value @@ -1958,13 +2152,15 @@ def _format_failed_rows(rows: List[str], ) return {"status": status, "results": results, "errors": errors} - elif res["clearGlobalKeysResult"]['jobStatus'] == "FAILED": + elif res["clearGlobalKeysResult"]["jobStatus"] == "FAILED": raise labelbox.exceptions.LabelboxError( - "Job clearGlobalKeys failed.") + "Job clearGlobalKeys failed." + ) current_time = time.time() if current_time - start_time > timeout_seconds: raise labelbox.exceptions.TimeoutError( - "Timed out waiting for clear_global_keys job to complete.") + "Timed out waiting for clear_global_keys job to complete." 
+ ) time.sleep(sleep_time) def get_catalog(self) -> Catalog: @@ -1990,11 +2186,12 @@ def get_catalog_slice(self, slice_id) -> CatalogSlice: } } """ - res = self.execute(query_str, {'id': slice_id}) - return Entity.CatalogSlice(self, res['getSavedQuery']) + res = self.execute(query_str, {"id": slice_id}) + return Entity.CatalogSlice(self, res["getSavedQuery"]) - def is_feature_schema_archived(self, ontology_id: str, - feature_schema_id: str) -> bool: + def is_feature_schema_archived( + self, ontology_id: str, feature_schema_id: str + ) -> bool: """ Returns true if a feature schema is archived in the specified ontology, returns false otherwise. @@ -2005,33 +2202,39 @@ def is_feature_schema_archived(self, ontology_id: str, bool """ - ontology_endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote( - ontology_id) + ontology_endpoint = ( + self.rest_endpoint + + "/ontologies/" + + urllib.parse.quote(ontology_id) + ) response = self._connection.get(ontology_endpoint) if response.status_code == requests.codes.ok: - feature_schema_nodes = response.json()['featureSchemaNodes'] - tools = feature_schema_nodes['tools'] - classifications = feature_schema_nodes['classifications'] - relationships = feature_schema_nodes['relationships'] + feature_schema_nodes = response.json()["featureSchemaNodes"] + tools = feature_schema_nodes["tools"] + classifications = feature_schema_nodes["classifications"] + relationships = feature_schema_nodes["relationships"] feature_schema_node_list = tools + classifications + relationships filtered_feature_schema_nodes = [ feature_schema_node for feature_schema_node in feature_schema_node_list - if feature_schema_node['featureSchemaId'] == feature_schema_id + if feature_schema_node["featureSchemaId"] == feature_schema_id ] if filtered_feature_schema_nodes: - return bool(filtered_feature_schema_nodes[0]['archived']) + return bool(filtered_feature_schema_nodes[0]["archived"]) else: raise labelbox.exceptions.LabelboxError( - "The specified feature schema was not in the ontology.") + "The specified feature schema was not in the ontology." + ) elif response.status_code == 404: raise labelbox.exceptions.ResourceNotFoundError( - Ontology, ontology_id) + Ontology, ontology_id + ) else: raise labelbox.exceptions.LabelboxError( - "Failed to get the feature schema archived status.") + "Failed to get the feature schema archived status." + ) def get_model_slice(self, slice_id) -> ModelSlice: """ @@ -2057,13 +2260,14 @@ def get_model_slice(self, slice_id) -> ModelSlice: res = self.execute(query_str, {"id": slice_id}) if res is None or res["getSavedQuery"] is None: raise labelbox.exceptions.ResourceNotFoundError( - ModelSlice, slice_id) + ModelSlice, slice_id + ) return Entity.ModelSlice(self, res["getSavedQuery"]) def delete_feature_schema_from_ontology( - self, ontology_id: str, - feature_schema_id: str) -> DeleteFeatureFromOntologyResult: + self, ontology_id: str, feature_schema_id: str + ) -> DeleteFeatureFromOntologyResult: """ Deletes or archives a feature schema from an ontology. If the feature schema is a root level node with associated labels, it will be archived. 
@@ -2080,31 +2284,38 @@ def delete_feature_schema_from_ontology( Example: >>> client.delete_feature_schema_from_ontology(, ) """ - ontology_endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote( - ontology_id) + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) + ontology_endpoint = ( + self.rest_endpoint + + "/ontologies/" + + urllib.parse.quote(ontology_id) + + "/feature-schemas/" + + urllib.parse.quote(feature_schema_id) + ) response = self._connection.delete(ontology_endpoint) if response.status_code == requests.codes.ok: response_json = response.json() - if response_json['archived'] == True: + if response_json["archived"] == True: logger.info( - 'Feature schema was archived from the ontology because it had associated labels.' + "Feature schema was archived from the ontology because it had associated labels." ) - elif response_json['deleted'] == True: + elif response_json["deleted"] == True: logger.info( - 'Feature schema was successfully removed from the ontology') + "Feature schema was successfully removed from the ontology" + ) result = DeleteFeatureFromOntologyResult() - result.archived = bool(response_json['archived']) - result.deleted = bool(response_json['deleted']) + result.archived = bool(response_json["archived"]) + result.deleted = bool(response_json["deleted"]) return result else: raise labelbox.exceptions.LabelboxError( - "Failed to remove feature schema from ontology, message: " + - str(response.json()['message'])) + "Failed to remove feature schema from ontology, message: " + + str(response.json()["message"]) + ) - def unarchive_feature_schema_node(self, ontology_id: str, - root_feature_schema_id: str) -> None: + def unarchive_feature_schema_node( + self, ontology_id: str, root_feature_schema_id: str + ) -> None: """ Unarchives a feature schema node in an ontology. Only root level feature schema nodes can be unarchived. @@ -2114,18 +2325,25 @@ def unarchive_feature_schema_node(self, ontology_id: str, Returns: None """ - ontology_endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote( - ontology_id) + '/feature-schemas/' + urllib.parse.quote( - root_feature_schema_id) + '/unarchive' + ontology_endpoint = ( + self.rest_endpoint + + "/ontologies/" + + urllib.parse.quote(ontology_id) + + "/feature-schemas/" + + urllib.parse.quote(root_feature_schema_id) + + "/unarchive" + ) response = self._connection.patch(ontology_endpoint) if response.status_code == requests.codes.ok: - if not bool(response.json()['unarchived']): + if not bool(response.json()["unarchived"]): raise labelbox.exceptions.LabelboxError( - "Failed unarchive the feature schema.") + "Failed unarchive the feature schema." 
+ ) else: raise labelbox.exceptions.LabelboxError( "Failed unarchive the feature schema node, message: ", - response.text) + response.text, + ) def get_batch(self, project_id: str, batch_id: str) -> Entity.Batch: # obtain batch entity to return @@ -2138,24 +2356,28 @@ def get_batch(self, project_id: str, batch_id: str) -> Entity.Batch: } } } - """ % ("getProjectBatchPyApi", - query.results_query_part(Entity.Batch)) + """ % ( + "getProjectBatchPyApi", + query.results_query_part(Entity.Batch), + ) batch = self.execute( - get_batch_str, { - "projectId": project_id, - "batchId": batch_id - }, + get_batch_str, + {"projectId": project_id, "batchId": batch_id}, timeout=180.0, - experimental=True)["project"]["batches"]["nodes"][0] + experimental=True, + )["project"]["batches"]["nodes"][0] return Entity.Batch(self, project_id, batch) - def send_to_annotate_from_catalog(self, destination_project_id: str, - task_queue_id: Optional[str], - batch_name: str, - data_rows: Union[DataRowIds, GlobalKeys], - params: Dict[str, Any]): + def send_to_annotate_from_catalog( + self, + destination_project_id: str, + task_queue_id: Optional[str], + batch_name: str, + data_rows: Union[DataRowIds, GlobalKeys], + params: Dict[str, Any], + ): """ Sends data rows from catalog to a specified project for annotation. @@ -2196,56 +2418,55 @@ def send_to_annotate_from_catalog(self, destination_project_id: str, """ destination_task_queue = build_destination_task_queue_input( - task_queue_id) + task_queue_id + ) data_rows_query = self.build_catalog_query(data_rows) - predictions_input = build_predictions_input( - validated_params.predictions_ontology_mapping, - validated_params.source_model_run_id - ) if validated_params.source_model_run_id else None - - annotations_input = build_annotations_input( - validated_params.annotations_ontology_mapping, validated_params. 
- source_project_id) if validated_params.source_project_id else None + predictions_input = ( + build_predictions_input( + validated_params.predictions_ontology_mapping, + validated_params.source_model_run_id, + ) + if validated_params.source_model_run_id + else None + ) + + annotations_input = ( + build_annotations_input( + validated_params.annotations_ontology_mapping, + validated_params.source_project_id, + ) + if validated_params.source_project_id + else None + ) res = self.execute( - mutation_str, { + mutation_str, + { "input": { - "destinationProjectId": - destination_project_id, + "destinationProjectId": destination_project_id, "batchInput": { "batchName": batch_name, - "batchPriority": validated_params.batch_priority - }, - "destinationTaskQueue": - destination_task_queue, - "excludeDataRowsInProject": - validated_params.exclude_data_rows_in_project, - "annotationsInput": - annotations_input, - "predictionsInput": - predictions_input, - "conflictLabelsResolutionStrategy": - validated_params.override_existing_annotations_rule, - "searchQuery": { - "scope": None, - "query": [data_rows_query] + "batchPriority": validated_params.batch_priority, }, + "destinationTaskQueue": destination_task_queue, + "excludeDataRowsInProject": validated_params.exclude_data_rows_in_project, + "annotationsInput": annotations_input, + "predictionsInput": predictions_input, + "conflictLabelsResolutionStrategy": validated_params.override_existing_annotations_rule, + "searchQuery": {"scope": None, "query": [data_rows_query]}, "ordering": { "type": "RANDOM", - "random": { - "seed": random.randint(0, 10000) - }, - "sorting": None + "random": {"seed": random.randint(0, 10000)}, + "sorting": None, }, - "sorting": - None, - "limit": - None + "sorting": None, + "limit": None, } - })['sendToAnnotateFromCatalog'] + }, + )["sendToAnnotateFromCatalog"] - return Entity.Task.get_task(self, res['taskId']) + return Entity.Task.get_task(self, res["taskId"]) @staticmethod def build_catalog_query(data_rows: Union[DataRowIds, GlobalKeys]): @@ -2262,13 +2483,13 @@ def build_catalog_query(data_rows: Union[DataRowIds, GlobalKeys]): data_rows_query = { "type": "data_row_id", "operator": "is", - "ids": list(data_rows) + "ids": list(data_rows), } elif isinstance(data_rows, GlobalKeys): data_rows_query = { "type": "global_key", "operator": "is", - "ids": list(data_rows) + "ids": list(data_rows), } else: raise ValueError( @@ -2276,9 +2497,12 @@ def build_catalog_query(data_rows: Union[DataRowIds, GlobalKeys]): ) return data_rows_query - def run_foundry_app(self, model_run_name: str, data_rows: Union[DataRowIds, - GlobalKeys], - app_id: str) -> Task: + def run_foundry_app( + self, + model_run_name: str, + data_rows: Union[DataRowIds, GlobalKeys], + app_id: str, + ) -> Task: """ Run a foundry app @@ -2345,11 +2569,13 @@ def get_embedding_by_name(self, name: str) -> Embedding: for e in embeddings: if e.name == name: return e - raise labelbox.exceptions.ResourceNotFoundError(Embedding, - dict(name=name)) + raise labelbox.exceptions.ResourceNotFoundError( + Embedding, dict(name=name) + ) - def upsert_label_feedback(self, label_id: str, feedback: str, - scores: Dict[str, float]) -> List[LabelScore]: + def upsert_label_feedback( + self, label_id: str, feedback: str, scores: Dict[str, float] + ) -> List[LabelScore]: """ Submits the label feedback which is a free-form text and numeric label scores. 
@@ -2385,15 +2611,14 @@ def upsert_label_feedback(self, label_id: str, feedback: str, } } """ - res = self.execute(mutation_str, { - "labelId": label_id, - "feedback": feedback, - "scores": scores - }) + res = self.execute( + mutation_str, + {"labelId": label_id, "feedback": feedback, "scores": scores}, + ) scores_raw = res["upsertAutoQaLabelFeedback"]["scores"] return [ - labelbox.LabelScore(name=x['name'], score=x['score']) + labelbox.LabelScore(name=x["name"], score=x["score"]) for x in scores_raw ] @@ -2406,12 +2631,12 @@ def get_labeling_service_dashboards( Optional parameters: search_query: A list of search filters representing the search - + NOTE: - Retrieves all projects for the organization or as filtered by the search query - INCLUDING those not requesting labeling services - Sorted by project created date in ascending order. - + Examples: Retrieves all labeling service dashboards for a given workspace id: >>> workspace_filter = WorkspaceFilter( @@ -2442,7 +2667,7 @@ def get_task_by_id(self, task_id: str) -> Union[Task, DataUpsertTask]: Returns: Task or DataUpsertTask - + Throws: ResourceNotFoundError: If the task does not exist. @@ -2471,9 +2696,10 @@ def get_task_by_id(self, task_id: str) -> Union[Task, DataUpsertTask]: data = result.get("user", {}).get("createdTasks", []) if not data: raise labelbox.exceptions.ResourceNotFoundError( - message=f"The task {task_id} does not exist.") + message=f"The task {task_id} does not exist." + ) task_data = data[0] - if task_data["type"].lower() == 'adv-upsert-data-rows': + if task_data["type"].lower() == "adv-upsert-data-rows": task = DataUpsertTask(self, task_data) else: task = Task(self, task_data) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py index 5b51814ec..7908bc242 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py @@ -64,4 +64,11 @@ from .llm_prompt_response.prompt import PromptText from .llm_prompt_response.prompt import PromptClassificationAnnotation -from .mmc import MessageInfo, OrderedMessageInfo, MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask, MessageEvaluationTaskAnnotation +from .mmc import ( + MessageInfo, + OrderedMessageInfo, + MessageSingleSelectionTask, + MessageMultiSelectionTask, + MessageRankingTask, + MessageEvaluationTaskAnnotation, +) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/annotation.py b/libs/labelbox/src/labelbox/data/annotation_types/annotation.py index 8a718751a..2c2f110a0 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/annotation.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/annotation.py @@ -5,7 +5,9 @@ from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin -from labelbox.data.annotation_types.classification.classification import ClassificationAnnotation +from labelbox.data.annotation_types.classification.classification import ( + ClassificationAnnotation, +) from .ner import DocumentEntity, TextEntity, ConversationEntity from typing import Optional diff --git a/libs/labelbox/src/labelbox/data/annotation_types/base_annotation.py b/libs/labelbox/src/labelbox/data/annotation_types/base_annotation.py index 27e66c063..ee9bf751b 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/base_annotation.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/base_annotation.py @@ -7,11 +7,11 @@ class BaseAnnotation(FeatureSchema, abc.ABC): - 
""" Base annotation class. Shouldn't be directly instantiated - """ + """Base annotation class. Shouldn't be directly instantiated""" + _uuid: Optional[UUID] = PrivateAttr() extra: Dict[str, Any] = {} - + model_config = ConfigDict(extra="allow") def __init__(self, **data): diff --git a/libs/labelbox/src/labelbox/data/annotation_types/classification/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/classification/__init__.py index 5bb098730..a814336e4 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/classification/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/classification/__init__.py @@ -1,2 +1 @@ -from .classification import (Checklist, ClassificationAnswer, Radio, - Text) +from .classification import Checklist, ClassificationAnswer, Radio, Text diff --git a/libs/labelbox/src/labelbox/data/annotation_types/classification/classification.py b/libs/labelbox/src/labelbox/data/annotation_types/classification/classification.py index 23c4c848a..d6a6448dd 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/classification/classification.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/classification/classification.py @@ -18,40 +18,45 @@ class ClassificationAnswer(FeatureSchema, ConfidenceMixin, CustomMetricsMixin): So unlike object annotations, classification annotations track keyframes at a classification answer level. """ + extra: Dict[str, Any] = {} keyframe: Optional[bool] = None - classifications: Optional[List['ClassificationAnnotation']] = None + classifications: Optional[List["ClassificationAnnotation"]] = None class Radio(ConfidenceMixin, CustomMetricsMixin, BaseModel): - """ A classification with only one selected option allowed + """A classification with only one selected option allowed >>> Radio(answer = ClassificationAnswer(name = "dog")) """ + answer: ClassificationAnswer class Checklist(ConfidenceMixin, BaseModel): - """ A classification with many selected options allowed + """A classification with many selected options allowed >>> Checklist(answer = [ClassificationAnswer(name = "cloudy")]) """ + answer: List[ClassificationAnswer] class Text(ConfidenceMixin, CustomMetricsMixin, BaseModel): - """ Free form text + """Free form text >>> Text(answer = "some text answer") """ + answer: str -class ClassificationAnnotation(BaseAnnotation, ConfidenceMixin, - CustomMetricsMixin): +class ClassificationAnnotation( + BaseAnnotation, ConfidenceMixin, CustomMetricsMixin +): """Classification annotations (non localized) >>> ClassificationAnnotation( @@ -65,7 +70,7 @@ class ClassificationAnnotation(BaseAnnotation, ConfidenceMixin, feature_schema_id (Optional[Cuid]) value (Union[Text, Checklist, Radio]) extra (Dict[str, Any]) - """ + """ value: Union[Text, Checklist, Radio] message_id: Optional[str] = None diff --git a/libs/labelbox/src/labelbox/data/annotation_types/collection.py b/libs/labelbox/src/labelbox/data/annotation_types/collection.py index 04c78a583..d90204309 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/collection.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/collection.py @@ -17,7 +17,7 @@ class LabelGenerator(PrefetchGenerator): """ - A container for interacting with a large collection of labels. + A container for interacting with a large collection of labels. For a small number of labels, just use a list of Label objects. 
""" @@ -26,21 +26,23 @@ def __init__(self, data: Generator[Label, None, None], *args, **kwargs): super().__init__(data, *args, **kwargs) def assign_feature_schema_ids( - self, - ontology_builder: "ontology.OntologyBuilder") -> "LabelGenerator": - + self, ontology_builder: "ontology.OntologyBuilder" + ) -> "LabelGenerator": def _assign_ids(label: Label): label.assign_feature_schema_ids(ontology_builder) return label - warnings.warn("This method is deprecated and will be " - "removed in a future release. Feature schema ids" - " are no longer required for importing.") - self._fns['assign_feature_schema_ids'] = _assign_ids + warnings.warn( + "This method is deprecated and will be " + "removed in a future release. Feature schema ids" + " are no longer required for importing." + ) + self._fns["assign_feature_schema_ids"] = _assign_ids return self - def add_url_to_data(self, signer: Callable[[bytes], - str]) -> "LabelGenerator": + def add_url_to_data( + self, signer: Callable[[bytes], str] + ) -> "LabelGenerator": """ Creates signed urls for the data Only uploads url if one doesn't already exist. @@ -55,11 +57,12 @@ def _add_url_to_data(label: Label): label.add_url_to_data(signer) return label - self._fns['add_url_to_data'] = _add_url_to_data + self._fns["add_url_to_data"] = _add_url_to_data return self - def add_to_dataset(self, dataset: "Entity.Dataset", - signer: Callable[[bytes], str]) -> "LabelGenerator": + def add_to_dataset( + self, dataset: "Entity.Dataset", signer: Callable[[bytes], str] + ) -> "LabelGenerator": """ Creates data rows from each labels data object and attaches the data to the given dataset. Updates the label's data object to have the same external_id and uid as the data row. @@ -75,11 +78,12 @@ def _add_to_dataset(label: Label): label.create_data_row(dataset, signer) return label - self._fns['assign_datarow_ids'] = _add_to_dataset + self._fns["assign_datarow_ids"] = _add_to_dataset return self - def add_url_to_masks(self, signer: Callable[[bytes], - str]) -> "LabelGenerator": + def add_url_to_masks( + self, signer: Callable[[bytes], str] + ) -> "LabelGenerator": """ Creates signed urls for all masks in the LabelGenerator. Multiple masks can reference the same MaskData so this makes sure we only upload that url once. @@ -97,11 +101,12 @@ def _add_url_to_masks(label: Label): label.add_url_to_masks(signer) return label - self._fns['add_url_to_masks'] = _add_url_to_masks + self._fns["add_url_to_masks"] = _add_url_to_masks return self - def register_background_fn(self, fn: Callable[[Label], Label], - name: str) -> "LabelGenerator": + def register_background_fn( + self, fn: Callable[[Label], Label], name: str + ) -> "LabelGenerator": """ Allows users to add arbitrary io functions to the generator. These functions will be exectuted in parallel and added to a prefetch queue. 
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py index 99978caac..2522b2741 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py @@ -9,4 +9,4 @@ from .video import VideoData from .llm_prompt_response_creation import LlmPromptResponseCreationData from .llm_prompt_creation import LlmPromptCreationData -from .llm_response_creation import LlmResponseCreationData \ No newline at end of file +from .llm_response_creation import LlmResponseCreationData diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/audio.py b/libs/labelbox/src/labelbox/data/annotation_types/data/audio.py index 76be33110..916fca99d 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/audio.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/audio.py @@ -4,4 +4,4 @@ class AudioData(BaseData, _NoCoercionMixin): - class_name: Literal["AudioData"] = "AudioData" \ No newline at end of file + class_name: Literal["AudioData"] = "AudioData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/base_data.py b/libs/labelbox/src/labelbox/data/annotation_types/data/base_data.py index 2ccda34c3..7d26ba5ca 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/base_data.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/base_data.py @@ -9,6 +9,7 @@ class BaseData(BaseModel, ABC): Base class for objects representing data. This class shouldn't directly be used """ + external_id: Optional[str] = None uid: Optional[str] = None global_key: Optional[str] = None diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/dicom.py b/libs/labelbox/src/labelbox/data/annotation_types/data/dicom.py index 753475c3e..ae4c377dc 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/dicom.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/dicom.py @@ -4,4 +4,4 @@ class DicomData(BaseData, _NoCoercionMixin): - class_name: Literal["DicomData"] = "DicomData" \ No newline at end of file + class_name: Literal["DicomData"] = "DicomData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/document.py b/libs/labelbox/src/labelbox/data/annotation_types/data/document.py index 5b2610c5b..810a3ed3e 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/document.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/document.py @@ -4,4 +4,4 @@ class DocumentData(BaseData, _NoCoercionMixin): - class_name: Literal["DocumentData"] = "DocumentData" \ No newline at end of file + class_name: Literal["DocumentData"] = "DocumentData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/generic_data_row_data.py b/libs/labelbox/src/labelbox/data/annotation_types/data/generic_data_row_data.py index 6a73519c1..9bb6a7e0a 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/generic_data_row_data.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/generic_data_row_data.py @@ -6,8 +6,8 @@ class GenericDataRowData(BaseData, _NoCoercionMixin): - """Generic data row data. This is replacing all other DataType passed into Label - """ + """Generic data row data. 
This is replacing all other DataType passed into Label""" + url: Optional[str] = None class_name: Literal["GenericDataRowData"] = "GenericDataRowData" @@ -17,7 +17,7 @@ def create_url(self, signer: Callable[[bytes], str]) -> Optional[str]: @model_validator(mode="before") @classmethod def validate_one_datarow_key_present(cls, data): - keys = ['external_id', 'global_key', 'uid'] + keys = ["external_id", "global_key", "uid"] count = sum([key in data for key in keys]) if count < 1: diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/html.py b/libs/labelbox/src/labelbox/data/annotation_types/data/html.py index 1820ce467..7a78fcb7b 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/html.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/html.py @@ -4,4 +4,4 @@ class HTMLData(BaseData, _NoCoercionMixin): - class_name: Literal["HTMLData"] = "HTMLData" \ No newline at end of file + class_name: Literal["HTMLData"] = "HTMLData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_creation.py b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_creation.py index 4fd788f1a..a1b0450bc 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_creation.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_creation.py @@ -4,4 +4,4 @@ class LlmPromptCreationData(BaseData, _NoCoercionMixin): - class_name: Literal["LlmPromptCreationData"] = "LlmPromptCreationData" \ No newline at end of file + class_name: Literal["LlmPromptCreationData"] = "LlmPromptCreationData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_response_creation.py b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_response_creation.py index 2bad75f6d..a8dfce894 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_response_creation.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_response_creation.py @@ -4,5 +4,6 @@ class LlmPromptResponseCreationData(BaseData, _NoCoercionMixin): - class_name: Literal[ - "LlmPromptResponseCreationData"] = "LlmPromptResponseCreationData" \ No newline at end of file + class_name: Literal["LlmPromptResponseCreationData"] = ( + "LlmPromptResponseCreationData" + ) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_response_creation.py b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_response_creation.py index 43c604e34..a8963ed3f 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_response_creation.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_response_creation.py @@ -4,4 +4,4 @@ class LlmResponseCreationData(BaseData, _NoCoercionMixin): - class_name: Literal["LlmResponseCreationData"] = "LlmResponseCreationData" \ No newline at end of file + class_name: Literal["LlmResponseCreationData"] = "LlmResponseCreationData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py b/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py index 234d8b136..ba4c6485f 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py @@ -16,21 +16,23 @@ class RasterData(BaseModel, ABC): - """Represents an image or segmentation mask. 
- """ + """Represents an image or segmentation mask.""" + im_bytes: Optional[bytes] = None file_path: Optional[str] = None url: Optional[str] = None uid: Optional[str] = None global_key: Optional[str] = None - arr: Optional[TypedArray[Literal['uint8']]] = None - + arr: Optional[TypedArray[Literal["uint8"]]] = None + model_config = ConfigDict(extra="forbid") @classmethod - def from_2D_arr(cls, arr: Union[TypedArray[Literal['uint8']], - TypedArray[Literal['int']]], - **kwargs) -> "RasterData": + def from_2D_arr( + cls, + arr: Union[TypedArray[Literal["uint8"]], TypedArray[Literal["int"]]], + **kwargs, + ) -> "RasterData": """Construct from a 2D numpy array Args: @@ -117,11 +119,12 @@ def value(self) -> np.ndarray: raise ValueError("Must set either url, file_path or im_bytes") def set_fetch_fn(self, fn): - object.__setattr__(self, 'fetch_remote', lambda: fn(self)) + object.__setattr__(self, "fetch_remote", lambda: fn(self)) - @retry.Retry(deadline=15., - predicate=retry.if_exception_type(ConnectTimeout, - InternalServerError)) + @retry.Retry( + deadline=15.0, + predicate=retry.if_exception_type(ConnectTimeout, InternalServerError), + ) def fetch_remote(self) -> bytes: """ Method for accessing url. @@ -135,7 +138,7 @@ def fetch_remote(self) -> bytes: response.raise_for_status() return response.content - @retry.Retry(deadline=30.) + @retry.Retry(deadline=30.0) def create_url(self, signer: Callable[[bytes], str]) -> str: """ Utility for creating a url from any of the other image representations. @@ -150,13 +153,14 @@ def create_url(self, signer: Callable[[bytes], str]) -> str: elif self.im_bytes is not None: self.url = signer(self.im_bytes) elif self.file_path is not None: - with open(self.file_path, 'rb') as file: + with open(self.file_path, "rb") as file: self.url = signer(file.read()) elif self.arr is not None: self.url = signer(self.np_to_bytes(self.arr)) else: raise ValueError( - "One of url, im_bytes, file_path, arr must not be None.") + "One of url, im_bytes, file_path, arr must not be None." + ) return self.url @model_validator(mode="after") @@ -167,7 +171,10 @@ def validate_args(self, values): arr = self.arr uid = self.uid global_key = self.global_key - if uid == file_path == im_bytes == url == global_key == None and arr is None: + if ( + uid == file_path == im_bytes == url == global_key == None + and arr is None + ): raise ValueError( "One of `file_path`, `im_bytes`, `url`, `uid`, `global_key` or `arr` required." ) @@ -179,15 +186,18 @@ def validate_args(self, values): elif len(arr.shape) != 3: raise ValueError( "unsupported image format. Must be 3D ([H,W,C])." - f"Use {self.__name__}.from_2D_arr to construct from 2D") + f"Use {self.__name__}.from_2D_arr to construct from 2D" + ) return self def __repr__(self) -> str: - symbol_or_none = lambda data: '...' if data is not None else None - return f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)}," \ - f"file_path={self.file_path}," \ - f"url={self.url}," \ - f"arr={symbol_or_none(self.arr)})" + symbol_or_none = lambda data: "..." if data is not None else None + return ( + f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)}," + f"file_path={self.file_path}," + f"url={self.url}," + f"arr={symbol_or_none(self.arr)})" + ) class MaskData(RasterData): @@ -212,5 +222,4 @@ class MaskData(RasterData): """ -class ImageData(RasterData, BaseData): - ... +class ImageData(RasterData, BaseData): ... 
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/text.py b/libs/labelbox/src/labelbox/data/annotation_types/data/text.py index 20624c161..fe4c222d3 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/text.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/text.py @@ -22,6 +22,7 @@ class TextData(BaseData, _NoCoercionMixin): text (str) url (str) """ + class_name: Literal["TextData"] = "TextData" file_path: Optional[str] = None text: Optional[str] = None @@ -51,11 +52,12 @@ def value(self) -> str: raise ValueError("Must set either url, file_path or im_bytes") def set_fetch_fn(self, fn): - object.__setattr__(self, 'fetch_remote', lambda: fn(self)) + object.__setattr__(self, "fetch_remote", lambda: fn(self)) - @retry.Retry(deadline=15., - predicate=retry.if_exception_type(ConnectTimeout, - InternalServerError)) + @retry.Retry( + deadline=15.0, + predicate=retry.if_exception_type(ConnectTimeout, InternalServerError), + ) def fetch_remote(self) -> str: """ Method for accessing url. @@ -69,7 +71,7 @@ def fetch_remote(self) -> str: response.raise_for_status() return response.text - @retry.Retry(deadline=15.) + @retry.Retry(deadline=15.0) def create_url(self, signer: Callable[[bytes], str]) -> None: """ Utility for creating a url from any of the other text references. @@ -82,13 +84,14 @@ def create_url(self, signer: Callable[[bytes], str]) -> None: if self.url is not None: return self.url elif self.file_path is not None: - with open(self.file_path, 'rb') as file: + with open(self.file_path, "rb") as file: self.url = signer(file.read()) elif self.text is not None: self.url = signer(self.text.encode()) else: raise ValueError( - "One of url, im_bytes, file_path, numpy must not be None.") + "One of url, im_bytes, file_path, numpy must not be None." + ) return self.url @model_validator(mode="after") @@ -105,6 +108,8 @@ def validate_date(self, values): return self def __repr__(self) -> str: - return f"TextData(file_path={self.file_path}," \ - f"text={self.text[:30] + '...' if self.text is not None else None}," \ - f"url={self.url})" + return ( + f"TextData(file_path={self.file_path}," + f"text={self.text[:30] + '...' if self.text is not None else None}," + f"url={self.url})" + ) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/tiled_image.py b/libs/labelbox/src/labelbox/data/annotation_types/data/tiled_image.py index 5d3561ceb..adb8db549 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/tiled_image.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/tiled_image.py @@ -29,19 +29,20 @@ class EPSG(Enum): - """ Provides the EPSG for tiled image assets that are currently supported. - + """Provides the EPSG for tiled image assets that are currently supported. + SIMPLEPIXEL is Simple that can be used to obtain the pixel space coordinates >>> epsg = EPSG() """ + SIMPLEPIXEL = 1 EPSG4326 = 4326 EPSG3857 = 3857 class TiledBounds(BaseModel): - """ Bounds for a tiled image asset related to the relevant epsg. + """Bounds for a tiled image asset related to the relevant epsg. Bounds should be Point objects. 
@@ -51,21 +52,22 @@ class TiledBounds(BaseModel): Point(x=-99.20534818927473, y=19.400498983095076) ]) """ + epsg: EPSG bounds: List[Point] - @field_validator('bounds') + @field_validator("bounds") def validate_bounds_not_equal(cls, bounds): first_bound = bounds[0] second_bound = bounds[1] - if first_bound.x == second_bound.x or \ - first_bound.y == second_bound.y: + if first_bound.x == second_bound.x or first_bound.y == second_bound.y: raise ValueError( - f"Bounds on either axes cannot be equal, currently {bounds}") + f"Bounds on either axes cannot be equal, currently {bounds}" + ) return bounds - #validate bounds are within lat,lng range if they are EPSG4326 + # validate bounds are within lat,lng range if they are EPSG4326 @model_validator(mode="after") def validate_bounds_lat_lng(self): epsg = self.epsg @@ -74,16 +76,20 @@ def validate_bounds_lat_lng(self): if epsg == EPSG.EPSG4326: for bound in bounds: lat, lng = bound.y, bound.x - if int(lng) not in VALID_LNG_RANGE or int( - lat) not in VALID_LAT_RANGE: - raise ValueError(f"Invalid lat/lng bounds. Found {bounds}. " - f"lat must be in {VALID_LAT_RANGE}. " - f"lng must be in {VALID_LNG_RANGE}.") + if ( + int(lng) not in VALID_LNG_RANGE + or int(lat) not in VALID_LAT_RANGE + ): + raise ValueError( + f"Invalid lat/lng bounds. Found {bounds}. " + f"lat must be in {VALID_LAT_RANGE}. " + f"lng must be in {VALID_LNG_RANGE}." + ) return self class TileLayer(BaseModel): - """ Url that contains the tile layer. Must be in the format: + """Url that contains the tile layer. Must be in the format: https://c.tile.openstreetmap.org/{z}/{x}/{y}.png @@ -92,13 +98,14 @@ class TileLayer(BaseModel): name="slippy map tile" ) """ + url: str name: Optional[str] = "default" def asdict(self) -> Dict[str, str]: return {"tileLayerUrl": self.url, "name": self.name} - @field_validator('url') + @field_validator("url") def validate_url(cls, url): xyz_format = "/{z}/{x}/{y}" if xyz_format not in url: @@ -107,7 +114,7 @@ def validate_url(cls, url): class TiledImageData(BaseData): - """ Represents tiled imagery + """Represents tiled imagery If specified version is 2, converts bounds from [lng,lat] to [lat,lng] @@ -119,12 +126,13 @@ class TiledImageData(BaseData): max_native_zoom: int = None tile_size: Optional[int] version: int = 2 - alternative_layers: List[TileLayer] + alternative_layers: List[TileLayer] >>> tiled_image_data = TiledImageData(tile_layer=TileLayer, tile_bounds=TiledBounds, zoom_levels=[1, 12]) """ + tile_layer: TileLayer tile_bounds: TiledBounds alternative_layers: List[TileLayer] = [] @@ -141,9 +149,10 @@ def __post_init__(self) -> None: def asdict(self) -> Dict[str, str]: return { "tileLayerUrl": self.tile_layer.url, - "bounds": [[ - self.tile_bounds.bounds[0].x, self.tile_bounds.bounds[0].y - ], [self.tile_bounds.bounds[1].x, self.tile_bounds.bounds[1].y]], + "bounds": [ + [self.tile_bounds.bounds[0].x, self.tile_bounds.bounds[0].y], + [self.tile_bounds.bounds[1].x, self.tile_bounds.bounds[1].y], + ], "minZoom": self.zoom_levels[0], "maxZoom": self.zoom_levels[1], "maxNativeZoom": self.max_native_zoom, @@ -152,13 +161,12 @@ def asdict(self) -> Dict[str, str]: "alternativeLayers": [ layer.asdict() for layer in self.alternative_layers ], - "version": self.version + "version": self.version, } - def raster_data(self, - zoom: int = 0, - max_tiles: int = 32, - multithread=True) -> RasterData: + def raster_data( + self, zoom: int = 0, max_tiles: int = 32, multithread=True + ) -> RasterData: """Converts the tiled image asset into a RasterData object 
containing an np.ndarray. @@ -168,26 +176,33 @@ def raster_data(self, xstart, ystart, xend, yend = self._get_simple_image_params(zoom) elif self.tile_bounds.epsg == EPSG.EPSG4326: xstart, ystart, xend, yend = self._get_3857_image_params( - zoom, self.tile_bounds) + zoom, self.tile_bounds + ) elif self.tile_bounds.epsg == EPSG.EPSG3857: - #transform to 4326 + # transform to 4326 transformer = EPSGTransformer.create_geo_to_geo_transformer( - EPSG.EPSG3857, EPSG.EPSG4326) + EPSG.EPSG3857, EPSG.EPSG4326 + ) transforming_bounds = [ transformer(self.tile_bounds.bounds[0]), - transformer(self.tile_bounds.bounds[1]) + transformer(self.tile_bounds.bounds[1]), ] xstart, ystart, xend, yend = self._get_3857_image_params( - zoom, transforming_bounds) + zoom, transforming_bounds + ) else: raise ValueError(f"Unsupported epsg found: {self.tile_bounds.epsg}") self._validate_num_tiles(xstart, ystart, xend, yend, max_tiles) rounded_tiles, pixel_offsets = list( - zip(*[ - self._tile_to_pixel(pt) for pt in [xstart, ystart, xend, yend] - ])) + zip( + *[ + self._tile_to_pixel(pt) + for pt in [xstart, ystart, xend, yend] + ] + ) + ) image = self._fetch_image_for_bounds(*rounded_tiles, zoom, multithread) arr = self._crop_to_bounds(image, *pixel_offsets) @@ -195,13 +210,14 @@ def raster_data(self, @property def value(self) -> np.ndarray: - """Returns the value of a generated RasterData object. - """ - return self.raster_data(self.zoom_levels[0], - multithread=self.multithread).value - - def _get_simple_image_params(self, - zoom) -> Tuple[float, float, float, float]: + """Returns the value of a generated RasterData object.""" + return self.raster_data( + self.zoom_levels[0], multithread=self.multithread + ).value + + def _get_simple_image_params( + self, zoom + ) -> Tuple[float, float, float, float]: """Computes the x and y tile bounds for fetching an image that captures the entire labeling region (TiledData.bounds) given a specific zoom @@ -214,14 +230,16 @@ def _get_simple_image_params(self, self.tile_bounds.bounds[1].y, self.tile_bounds.bounds[0].y, ) - return (*[ - x * (2**(zoom)) / self.tile_size - for x in [xstart, ystart, xend, yend] - ],) + return ( + *[ + x * (2 ** (zoom)) / self.tile_size + for x in [xstart, ystart, xend, yend] + ], + ) def _get_3857_image_params( - self, zoom: int, - bounds: TiledBounds) -> Tuple[float, float, float, float]: + self, zoom: int, bounds: TiledBounds + ) -> Tuple[float, float, float, float]: """Computes the x and y tile bounds for fetching an image that captures the entire labeling region (TiledData.bounds) given a specific zoom """ @@ -237,10 +255,9 @@ def _get_3857_image_params( ystart, yend = min(ystart, yend), max(ystart, yend) return (*[pt * 2.0**zoom for pt in [xstart, ystart, xend, yend]],) - def _latlng_to_tile(self, - lat: float, - lng: float, - zoom=0) -> Tuple[float, float]: + def _latlng_to_tile( + self, lat: float, lng: float, zoom=0 + ) -> Tuple[float, float]: """Converts lat/lng to 3857 tile coordinates Formula found here: https://wiki.openstreetmap.org/wiki/Slippy_map_tilenames#lon.2Flat_to_tile_numbers_2 @@ -252,29 +269,31 @@ def _latlng_to_tile(self, return x, y def _tile_to_pixel(self, tile: float) -> Tuple[int, int]: - """Rounds a tile coordinate and reports the remainder in pixels - """ + """Rounds a tile coordinate and reports the remainder in pixels""" rounded_tile = int(tile) remainder = tile - rounded_tile pixel_offset = int(self.tile_size * remainder) return rounded_tile, pixel_offset - def _fetch_image_for_bounds(self, - x_tile_start: int, - 
y_tile_start: int, - x_tile_end: int, - y_tile_end: int, - zoom: int, - multithread=True) -> np.ndarray: - """Fetches the tiles and combines them into a single image. - + def _fetch_image_for_bounds( + self, + x_tile_start: int, + y_tile_start: int, + x_tile_end: int, + y_tile_end: int, + zoom: int, + multithread=True, + ) -> np.ndarray: + """Fetches the tiles and combines them into a single image. + If a tile cannot be fetched, a padding of expected tile size is instead added. """ if multithread: tiles = {} with ThreadPoolExecutor( - max_workers=TILE_DOWNLOAD_CONCURRENCY) as exc: + max_workers=TILE_DOWNLOAD_CONCURRENCY + ) as exc: for x in range(x_tile_start, x_tile_end + 1): for y in range(y_tile_start, y_tile_end + 1): tiles[(x, y)] = exc.submit(self._fetch_tile, x, y, zoom) @@ -290,8 +309,11 @@ def _fetch_image_for_bounds(self, row.append(self._fetch_tile(x, y, zoom)) except: row.append( - np.zeros(shape=(self.tile_size, self.tile_size, 3), - dtype=np.uint8)) + np.zeros( + shape=(self.tile_size, self.tile_size, 3), + dtype=np.uint8, + ) + ) rows.append(np.hstack(row)) return np.vstack(rows) @@ -331,19 +353,27 @@ def invert_point(pt): x_px_end, y_px_end = invert_point(x_px_end), invert_point(y_px_end) return image[y_px_start:y_px_end, x_px_start:x_px_end, :] - def _validate_num_tiles(self, xstart: float, ystart: float, xend: float, - yend: float, max_tiles: int): + def _validate_num_tiles( + self, + xstart: float, + ystart: float, + xend: float, + yend: float, + max_tiles: int, + ): """Calculates the number of expected tiles we would fetch. If this is greater than the number of max tiles, raise an error. """ total_n_tiles = (yend - ystart + 1) * (xend - xstart + 1) if total_n_tiles > max_tiles: - raise ValueError(f"Requested zoom results in {total_n_tiles} tiles." - f"Max allowed tiles are {max_tiles}" - f"Increase max tiles or reduce zoom level.") + raise ValueError( + f"Requested zoom results in {total_n_tiles} tiles." + f"Max allowed tiles are {max_tiles}" + f"Increase max tiles or reduce zoom level." + ) - @field_validator('zoom_levels') + @field_validator("zoom_levels") def validate_zoom_levels(cls, zoom_levels): if zoom_levels[0] > zoom_levels[1]: raise ValueError( @@ -356,8 +386,9 @@ class EPSGTransformer(BaseModel): """Transformer class between different EPSG's. Useful when wanting to project in different formats. """ + transformer: Any - model_config = ConfigDict(arbitrary_types_allowed = True) + model_config = ConfigDict(arbitrary_types_allowed=True) @staticmethod def _is_simple(epsg: EPSG) -> bool: @@ -366,7 +397,7 @@ def _is_simple(epsg: EPSG) -> bool: @staticmethod def _get_ranges(bounds: np.ndarray) -> Tuple[int, int]: """helper function to get the range between bounds. 
- + returns a tuple (x_range, y_range)""" x_range = np.max(bounds[:, 0]) - np.min(bounds[:, 0]) y_range = np.max(bounds[:, 1]) - np.min(bounds[:, 1]) @@ -374,90 +405,107 @@ def _get_ranges(bounds: np.ndarray) -> Tuple[int, int]: @staticmethod def _min_max_x_y(bounds: np.ndarray) -> Tuple[int, int, int, int]: - """returns the min x, max x, min y, max y of a numpy array - """ - return np.min(bounds[:, 0]), np.max(bounds[:, 0]), np.min( - bounds[:, 1]), np.max(bounds[:, 1]) + """returns the min x, max x, min y, max y of a numpy array""" + return ( + np.min(bounds[:, 0]), + np.max(bounds[:, 0]), + np.min(bounds[:, 1]), + np.max(bounds[:, 1]), + ) @classmethod - def geo_and_pixel(cls, - src_epsg, - pixel_bounds: TiledBounds, - geo_bounds: TiledBounds, - zoom=0) -> Callable: + def geo_and_pixel( + cls, + src_epsg, + pixel_bounds: TiledBounds, + geo_bounds: TiledBounds, + zoom=0, + ) -> Callable: """method to change from one projection to simple projection""" pixel_bounds = pixel_bounds.bounds geo_bounds_epsg = geo_bounds.epsg geo_bounds = geo_bounds.bounds - local_bounds = np.array([(point.x, point.y) for point in pixel_bounds], - dtype=int) - #convert geo bounds to pixel bounds. assumes geo bounds are in wgs84/EPS4326 per leaflet - global_bounds = np.array([ - PygeoPoint.from_latitude_longitude(latitude=point.y, - longitude=point.x).pixels(zoom) - for point in geo_bounds - ]) + local_bounds = np.array( + [(point.x, point.y) for point in pixel_bounds], dtype=int + ) + # convert geo bounds to pixel bounds. assumes geo bounds are in wgs84/EPS4326 per leaflet + global_bounds = np.array( + [ + PygeoPoint.from_latitude_longitude( + latitude=point.y, longitude=point.x + ).pixels(zoom) + for point in geo_bounds + ] + ) - #get the range of pixels for both sets of bounds to use as a multiplification factor + # get the range of pixels for both sets of bounds to use as a multiplification factor local_x_range, local_y_range = cls._get_ranges(bounds=local_bounds) global_x_range, global_y_range = cls._get_ranges(bounds=global_bounds) if src_epsg == EPSG.SIMPLEPIXEL: def transform(x: int, y: int) -> Callable[[int, int], Transformer]: - scaled_xy = (x * (global_x_range) / (local_x_range), - y * (global_y_range) / (local_y_range)) + scaled_xy = ( + x * (global_x_range) / (local_x_range), + y * (global_y_range) / (local_y_range), + ) minx, _, miny, _ = cls._min_max_x_y(bounds=global_bounds) x, y = map(lambda i, j: i + j, scaled_xy, (minx, miny)) - point = PygeoPoint.from_pixel(pixel_x=x, pixel_y=y, - zoom=zoom).latitude_longitude - #convert to the desired epsg - return Transformer.from_crs(EPSG.EPSG4326.value, - geo_bounds_epsg.value, - always_xy=True).transform( - point[1], point[0]) + point = PygeoPoint.from_pixel( + pixel_x=x, pixel_y=y, zoom=zoom + ).latitude_longitude + # convert to the desired epsg + return Transformer.from_crs( + EPSG.EPSG4326.value, geo_bounds_epsg.value, always_xy=True + ).transform(point[1], point[0]) return transform - #handles 4326 from lat,lng + # handles 4326 from lat,lng elif src_epsg == EPSG.EPSG4326: def transform(x: int, y: int) -> Callable[[int, int], Transformer]: point_in_px = PygeoPoint.from_latitude_longitude( - latitude=y, longitude=x).pixels(zoom) + latitude=y, longitude=x + ).pixels(zoom) minx, _, miny, _ = cls._min_max_x_y(global_bounds) x, y = map(lambda i, j: i - j, point_in_px, (minx, miny)) - return (x * (local_x_range) / (global_x_range), - y * (local_y_range) / (global_y_range)) + return ( + x * (local_x_range) / (global_x_range), + y * (local_y_range) / 
(global_y_range), + ) return transform - #handles 3857 from meters + # handles 3857 from meters elif src_epsg == EPSG.EPSG3857: def transform(x: int, y: int) -> Callable[[int, int], Transformer]: - point_in_px = PygeoPoint.from_meters(meter_y=y, - meter_x=x).pixels(zoom) + point_in_px = PygeoPoint.from_meters( + meter_y=y, meter_x=x + ).pixels(zoom) minx, _, miny, _ = cls._min_max_x_y(global_bounds) x, y = map(lambda i, j: i - j, point_in_px, (minx, miny)) - return (x * (local_x_range) / (global_x_range), - y * (local_y_range) / (global_y_range)) + return ( + x * (local_x_range) / (global_x_range), + y * (local_y_range) / (global_y_range), + ) return transform @classmethod def create_geo_to_geo_transformer( - cls, src_epsg: EPSG, - tgt_epsg: EPSG) -> Callable[[int, int], Transformer]: - """method to change from one projection to another projection. + cls, src_epsg: EPSG, tgt_epsg: EPSG + ) -> Callable[[int, int], Transformer]: + """method to change from one projection to another projection. supports EPSG transformations not Simple. """ @@ -466,36 +514,45 @@ def create_geo_to_geo_transformer( f"Cannot be used for Simple transformations. Found {src_epsg} and {tgt_epsg}" ) - return EPSGTransformer(transformer=Transformer.from_crs( - src_epsg.value, tgt_epsg.value, always_xy=True).transform) + return EPSGTransformer( + transformer=Transformer.from_crs( + src_epsg.value, tgt_epsg.value, always_xy=True + ).transform + ) @classmethod def create_geo_to_pixel_transformer( - cls, - src_epsg, - pixel_bounds: TiledBounds, - geo_bounds: TiledBounds, - zoom=0) -> Callable[[int, int], Transformer]: + cls, + src_epsg, + pixel_bounds: TiledBounds, + geo_bounds: TiledBounds, + zoom=0, + ) -> Callable[[int, int], Transformer]: """method to change from a geo projection to Simple""" - transform_function = cls.geo_and_pixel(src_epsg=src_epsg, - pixel_bounds=pixel_bounds, - geo_bounds=geo_bounds, - zoom=zoom) + transform_function = cls.geo_and_pixel( + src_epsg=src_epsg, + pixel_bounds=pixel_bounds, + geo_bounds=geo_bounds, + zoom=zoom, + ) return EPSGTransformer(transformer=transform_function) @classmethod def create_pixel_to_geo_transformer( - cls, - src_epsg, - pixel_bounds: TiledBounds, - geo_bounds: TiledBounds, - zoom=0) -> Callable[[int, int], Transformer]: + cls, + src_epsg, + pixel_bounds: TiledBounds, + geo_bounds: TiledBounds, + zoom=0, + ) -> Callable[[int, int], Transformer]: """method to change from a geo projection to Simple""" - transform_function = cls.geo_and_pixel(src_epsg=src_epsg, - pixel_bounds=pixel_bounds, - geo_bounds=geo_bounds, - zoom=zoom) + transform_function = cls.geo_and_pixel( + src_epsg=src_epsg, + pixel_bounds=pixel_bounds, + geo_bounds=geo_bounds, + zoom=zoom, + ) return EPSGTransformer(transformer=transform_function) def _get_point_obj(self, point) -> Point: @@ -513,9 +570,12 @@ def __call__( return Line(points=[self._get_point_obj(p) for p in shape.points]) if isinstance(shape, Polygon): return Polygon( - points=[self._get_point_obj(p) for p in shape.points]) + points=[self._get_point_obj(p) for p in shape.points] + ) if isinstance(shape, Rectangle): - return Rectangle(start=self._get_point_obj(shape.start), - end=self._get_point_obj(shape.end)) + return Rectangle( + start=self._get_point_obj(shape.start), + end=self._get_point_obj(shape.end), + ) else: - raise ValueError(f"Unsupported type found: {type(shape)}") \ No newline at end of file + raise ValueError(f"Unsupported type found: {type(shape)}") diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/video.py 
b/libs/labelbox/src/labelbox/data/annotation_types/data/video.py index 5d7804860..581801036 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/video.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/video.py @@ -21,11 +21,12 @@ class VideoData(BaseData): """ Represents video """ + file_path: Optional[str] = None url: Optional[str] = None - frames: Optional[Dict[int, TypedArray[Literal['uint8']]]] = None + frames: Optional[Dict[int, TypedArray[Literal["uint8"]]]] = None # Required for discriminating between data types - model_config = ConfigDict(extra = "forbid") + model_config = ConfigDict(extra="forbid") def load_frames(self, overwrite: bool = False) -> None: """ @@ -48,9 +49,7 @@ def value(self): return self.frame_generator() def frame_generator( - self, - cache_frames=False, - download_dir='/tmp' + self, cache_frames=False, download_dir="/tmp" ) -> Generator[Tuple[int, np.ndarray], None, None]: """ A generator for accessing individual frames in a video. @@ -91,9 +90,9 @@ def __getitem__(self, idx: int) -> np.ndarray: return self.frames[idx] def set_fetch_fn(self, fn): - object.__setattr__(self, 'fetch_remote', lambda: fn(self)) + object.__setattr__(self, "fetch_remote", lambda: fn(self)) - @retry.Retry(deadline=15.) + @retry.Retry(deadline=15.0) def fetch_remote(self, local_path) -> None: """ Method for downloading data from self.url @@ -106,7 +105,7 @@ def fetch_remote(self, local_path) -> None: """ urllib.request.urlretrieve(self.url, local_path) - @retry.Retry(deadline=15.) + @retry.Retry(deadline=15.0) def create_url(self, signer: Callable[[bytes], str]) -> None: """ Utility for creating a url from any of the other video references. @@ -119,7 +118,7 @@ def create_url(self, signer: Callable[[bytes], str]) -> None: if self.url is not None: return self.url elif self.file_path is not None: - with open(self.file_path, 'rb') as file: + with open(self.file_path, "rb") as file: self.url = signer(file.read()) elif self.frames is not None: self.file_path = self.frames_to_video(self.frames) @@ -128,10 +127,9 @@ def create_url(self, signer: Callable[[bytes], str]) -> None: raise ValueError("One of url, file_path, frames must not be None.") return self.url - def frames_to_video(self, - frames: Dict[int, np.ndarray], - fps=20, - save_dir='/tmp') -> str: + def frames_to_video( + self, frames: Dict[int, np.ndarray], fps=20, save_dir="/tmp" + ) -> str: """ Compresses the data by converting a set of individual frames to a single video. @@ -141,9 +139,12 @@ def frames_to_video(self, for key in frames.keys(): frame = frames[key] if out is None: - out = cv2.VideoWriter(file_path, - cv2.VideoWriter_fourcc(*'MP4V'), fps, - frame.shape[:2]) + out = cv2.VideoWriter( + file_path, + cv2.VideoWriter_fourcc(*"MP4V"), + fps, + frame.shape[:2], + ) out.write(frame) if out is None: return @@ -165,6 +166,8 @@ def validate_data(self): return self def __repr__(self) -> str: - return f"VideoData(file_path={self.file_path}," \ - f"frames={'...' if self.frames is not None else None}," \ - f"url={self.url})" + return ( + f"VideoData(file_path={self.file_path}," + f"frames={'...' 
if self.frames is not None else None}," + f"url={self.url})" + ) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/feature.py b/libs/labelbox/src/labelbox/data/annotation_types/feature.py index 836817aeb..5b4591abc 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/feature.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/feature.py @@ -9,6 +9,7 @@ class FeatureSchema(BaseModel): Could be a annotation, a subclass, or an option. Schema ids might not be known when constructing these objects so both a name and schema id are valid. """ + name: Optional[str] = None feature_schema_id: Optional[Cuid] = None diff --git a/libs/labelbox/src/labelbox/data/annotation_types/geometry/geometry.py b/libs/labelbox/src/labelbox/data/annotation_types/geometry/geometry.py index acdfa94c2..7b5b42cd5 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/geometry/geometry.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/geometry/geometry.py @@ -9,23 +9,34 @@ class Geometry(BaseModel, ABC): - """Abstract base class for geometry objects - """ + """Abstract base class for geometry objects""" + extra: Dict[str, Any] = {} @property def shapely( - self - ) -> Union[geom.Point, geom.LineString, geom.Polygon, geom.MultiPoint, - geom.MultiLineString, geom.MultiPolygon]: + self, + ) -> Union[ + geom.Point, + geom.LineString, + geom.Polygon, + geom.MultiPoint, + geom.MultiLineString, + geom.MultiPolygon, + ]: return geom.shape(self.geometry) - def get_or_create_canvas(self, height: Optional[int], width: Optional[int], - canvas: Optional[np.ndarray]) -> np.ndarray: + def get_or_create_canvas( + self, + height: Optional[int], + width: Optional[int], + canvas: Optional[np.ndarray], + ) -> np.ndarray: if canvas is None: if height is None or width is None: raise ValueError( - "Must either provide canvas or height and width") + "Must either provide canvas or height and width" + ) canvas = np.zeros((height, width, 3), dtype=np.uint8) canvas = np.ascontiguousarray(canvas) return canvas @@ -36,10 +47,12 @@ def geometry(self) -> geojson: pass @abstractmethod - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Optional[Union[int, Tuple[int, int, int]]] = None, - thickness: Optional[int] = 1) -> np.ndarray: + def draw( + self, + height: Optional[int] = None, + width: Optional[int] = None, + canvas: Optional[np.ndarray] = None, + color: Optional[Union[int, Tuple[int, int, int]]] = None, + thickness: Optional[int] = 1, + ) -> np.ndarray: pass diff --git a/libs/labelbox/src/labelbox/data/annotation_types/geometry/line.py b/libs/labelbox/src/labelbox/data/annotation_types/geometry/line.py index fcd31b4e7..d8ea52f0c 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/geometry/line.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/geometry/line.py @@ -11,6 +11,7 @@ from pydantic import field_validator + class Line(Geometry): """Line annotation @@ -20,30 +21,36 @@ class Line(Geometry): >>> Line(points = [Point(x=3,y=4), Point(x=3,y=5)]) """ + points: List[Point] @property def geometry(self) -> geojson.MultiLineString: return geojson.MultiLineString( - [[[point.x, point.y] for point in self.points]]) + [[[point.x, point.y] for point in self.points]] + ) @classmethod def from_shapely(cls, shapely_obj: SLineString) -> "Line": """Transforms a shapely object.""" if not isinstance(shapely_obj, SLineString): raise TypeError( - f"Expected Shapely Line. Got {shapely_obj.geom_type}") + f"Expected Shapely Line. 
Got {shapely_obj.geom_type}" + ) - obj_coords = shapely_obj.__geo_interface__['coordinates'] + obj_coords = shapely_obj.__geo_interface__["coordinates"] return Line( - points=[Point(x=coords[0], y=coords[1]) for coords in obj_coords]) - - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Union[int, Tuple[int, int, int]] = (255, 255, 255), - thickness: int = 1) -> np.ndarray: + points=[Point(x=coords[0], y=coords[1]) for coords in obj_coords] + ) + + def draw( + self, + height: Optional[int] = None, + width: Optional[int] = None, + canvas: Optional[np.ndarray] = None, + color: Union[int, Tuple[int, int, int]] = (255, 255, 255), + thickness: int = 1, + ) -> np.ndarray: """ Draw the line onto a 3d mask Args: @@ -57,14 +64,12 @@ def draw(self, numpy array representing the mask with the line drawn on it. """ canvas = self.get_or_create_canvas(height, width, canvas) - pts = np.array(self.geometry['coordinates']).astype(np.int32) - return cv2.polylines(canvas, - pts, - False, - color=color, - thickness=thickness) - - @field_validator('points') + pts = np.array(self.geometry["coordinates"]).astype(np.int32) + return cv2.polylines( + canvas, pts, False, color=color, thickness=thickness + ) + + @field_validator("points") def is_geom_valid(cls, points): if len(points) < 2: raise ValueError( diff --git a/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py b/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py index 39051182f..0d870f24f 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py @@ -40,21 +40,22 @@ class Mask(Geometry): @property def geometry(self) -> Dict[str, Tuple[int, int, int]]: mask = self.draw(color=1) - contours, hierarchy = cv2.findContours(image=mask, - mode=cv2.RETR_TREE, - method=cv2.CHAIN_APPROX_NONE) + contours, hierarchy = cv2.findContours( + image=mask, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_NONE + ) holes = [] external_contours = [] for i in range(len(contours)): if hierarchy[0, i, 3] != -1: - #determined to be a hole based on contour hierarchy + # determined to be a hole based on contour hierarchy holes.append(contours[i]) else: external_contours.append(contours[i]) external_polygons = self._extract_polygons_from_contours( - external_contours) + external_contours + ) holes = self._extract_polygons_from_contours(holes) if not external_polygons.is_valid: @@ -65,12 +66,14 @@ def geometry(self) -> Dict[str, Tuple[int, int, int]]: return external_polygons.difference(holes).__geo_interface__ - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Optional[Union[int, Tuple[int, int, int]]] = None, - thickness=None) -> np.ndarray: + def draw( + self, + height: Optional[int] = None, + width: Optional[int] = None, + canvas: Optional[np.ndarray] = None, + color: Optional[Union[int, Tuple[int, int, int]]] = None, + thickness=None, + ) -> np.ndarray: """Converts the Mask object into a numpy array Args: @@ -91,16 +94,20 @@ def draw(self, mask = np.alltrue(mask == self.color, axis=2).astype(np.uint8) if height is not None or width is not None: - mask = cv2.resize(mask, - (width or mask.shape[1], height or mask.shape[0])) + mask = cv2.resize( + mask, (width or mask.shape[1], height or mask.shape[0]) + ) dims = [mask.shape[0], mask.shape[1]] color = color or self.color if isinstance(color, (tuple, list)): dims = dims + 
[len(color)] - canvas = canvas if canvas is not None else np.zeros(tuple(dims), - dtype=np.uint8) + canvas = ( + canvas + if canvas is not None + else np.zeros(tuple(dims), dtype=np.uint8) + ) canvas[mask.astype(bool)] = color return canvas @@ -122,7 +129,7 @@ def create_url(self, signer: Callable[[bytes], str]) -> str: """ return self.mask.create_url(signer) - @field_validator('color') + @field_validator("color") def is_valid_color(cls, color): if isinstance(color, (tuple, list)): if len(color) == 1: @@ -137,6 +144,7 @@ def is_valid_color(cls, color): ) elif not (0 <= color <= 255): raise ValueError( - f"All rgb colors must be between 0 and 255. Found : {color}") + f"All rgb colors must be between 0 and 255. Found : {color}" + ) return color diff --git a/libs/labelbox/src/labelbox/data/annotation_types/geometry/point.py b/libs/labelbox/src/labelbox/data/annotation_types/geometry/point.py index c3f736e76..c801628f9 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/geometry/point.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/geometry/point.py @@ -18,6 +18,7 @@ class Point(Geometry): y (float) """ + x: float y: float @@ -30,17 +31,20 @@ def from_shapely(cls, shapely_obj: SPoint) -> "Point": """Transforms a shapely object.""" if not isinstance(shapely_obj, SPoint): raise TypeError( - f"Expected Shapely Point. Got {shapely_obj.geom_type}") + f"Expected Shapely Point. Got {shapely_obj.geom_type}" + ) - obj_coords = shapely_obj.__geo_interface__['coordinates'] + obj_coords = shapely_obj.__geo_interface__["coordinates"] return Point(x=obj_coords[0], y=obj_coords[1]) - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Union[int, Tuple[int, int, int]] = (255, 255, 255), - thickness: int = 10) -> np.ndarray: + def draw( + self, + height: Optional[int] = None, + width: Optional[int] = None, + canvas: Optional[np.ndarray] = None, + color: Union[int, Tuple[int, int, int]] = (255, 255, 255), + thickness: int = 10, + ) -> np.ndarray: """ Draw the point onto a 3d mask Args: @@ -54,7 +58,10 @@ def draw(self, numpy array representing the mask with the point drawn on it. """ canvas = self.get_or_create_canvas(height, width, canvas) - return cv2.circle(canvas, (int(self.x), int(self.y)), - radius=thickness, - color=color, - thickness=-1) + return cv2.circle( + canvas, + (int(self.x), int(self.y)), + radius=thickness, + color=color, + thickness=-1, + ) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/geometry/polygon.py b/libs/labelbox/src/labelbox/data/annotation_types/geometry/polygon.py index 96e1f0c94..9785e7ab4 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/geometry/polygon.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/geometry/polygon.py @@ -25,6 +25,7 @@ class Polygon(Geometry): point is added to close it. """ + points: List[Point] @property @@ -36,20 +37,24 @@ def geometry(self) -> geojson.Polygon: @classmethod def from_shapely(cls, shapely_obj: SPolygon) -> "Polygon": """Transforms a shapely object.""" - #we only consider 0th index because we only allow for filled polygons + # we only consider 0th index because we only allow for filled polygons if not isinstance(shapely_obj, SPolygon): raise TypeError( - f"Expected Shapely Polygon. Got {shapely_obj.geom_type}") - obj_coords = shapely_obj.__geo_interface__['coordinates'][0] + f"Expected Shapely Polygon. 
Got {shapely_obj.geom_type}" + ) + obj_coords = shapely_obj.__geo_interface__["coordinates"][0] return Polygon( - points=[Point(x=coords[0], y=coords[1]) for coords in obj_coords]) - - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Union[int, Tuple[int, int, int]] = (255, 255, 255), - thickness: int = -1) -> np.ndarray: + points=[Point(x=coords[0], y=coords[1]) for coords in obj_coords] + ) + + def draw( + self, + height: Optional[int] = None, + width: Optional[int] = None, + canvas: Optional[np.ndarray] = None, + color: Union[int, Tuple[int, int, int]] = (255, 255, 255), + thickness: int = -1, + ) -> np.ndarray: """ Draw the polygon onto a 3d mask Args: @@ -63,12 +68,12 @@ def draw(self, numpy array representing the mask with the polygon drawn on it. """ canvas = self.get_or_create_canvas(height, width, canvas) - pts = np.array(self.geometry['coordinates']).astype(np.int32) + pts = np.array(self.geometry["coordinates"]).astype(np.int32) if thickness == -1: return cv2.fillPoly(canvas, pts, color) return cv2.polylines(canvas, pts, True, color, thickness) - @field_validator('points') + @field_validator("points") def is_geom_valid(cls, points): if len(points) < 3: raise ValueError( diff --git a/libs/labelbox/src/labelbox/data/annotation_types/geometry/rectangle.py b/libs/labelbox/src/labelbox/data/annotation_types/geometry/rectangle.py index 3c43d44ba..5cabf0957 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/geometry/rectangle.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/geometry/rectangle.py @@ -20,43 +20,52 @@ class Rectangle(Geometry): start (Point): Top left coordinate of the rectangle end (Point): Bottom right coordinate of the rectangle """ + start: Point end: Point @property def geometry(self) -> geojson.geometry.Geometry: - return geojson.Polygon([[ - [self.start.x, self.start.y], - [self.start.x, self.end.y], - [self.end.x, self.end.y], - [self.end.x, self.start.y], - [self.start.x, self.start.y], - ]]) + return geojson.Polygon( + [ + [ + [self.start.x, self.start.y], + [self.start.x, self.end.y], + [self.end.x, self.end.y], + [self.end.x, self.start.y], + [self.start.x, self.start.y], + ] + ] + ) @classmethod def from_shapely(cls, shapely_obj: SPolygon) -> "Rectangle": """Transforms a shapely object. - + If the provided shape is a non-rectangular polygon, a rectangle will be returned based on the min and max x,y values.""" if not isinstance(shapely_obj, SPolygon): raise TypeError( - f"Expected Shapely Polygon. Got {shapely_obj.geom_type}") + f"Expected Shapely Polygon. Got {shapely_obj.geom_type}" + ) min_x, min_y, max_x, max_y = shapely_obj.bounds start = [min_x, min_y] end = [max_x, max_y] - return Rectangle(start=Point(x=start[0], y=start[1]), - end=Point(x=end[0], y=end[1])) - - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Union[int, Tuple[int, int, int]] = (255, 255, 255), - thickness: int = -1) -> np.ndarray: + return Rectangle( + start=Point(x=start[0], y=start[1]), end=Point(x=end[0], y=end[1]) + ) + + def draw( + self, + height: Optional[int] = None, + width: Optional[int] = None, + canvas: Optional[np.ndarray] = None, + color: Union[int, Tuple[int, int, int]] = (255, 255, 255), + thickness: int = -1, + ) -> np.ndarray: """ Draw the rectangle onto a 3d mask Args: @@ -70,7 +79,7 @@ def draw(self, numpy array representing the mask with the rectangle drawn on it. 
""" canvas = self.get_or_create_canvas(height, width, canvas) - pts = np.array(self.geometry['coordinates']).astype(np.int32) + pts = np.array(self.geometry["coordinates"]).astype(np.int32) if thickness == -1: return cv2.fillPoly(canvas, pts, color) return cv2.polylines(canvas, pts, True, color, thickness) @@ -82,9 +91,9 @@ def from_xyhw(cls, x: float, y: float, h: float, w: float) -> "Rectangle": class RectangleUnit(Enum): - INCHES = 'INCHES' - PIXELS = 'PIXELS' - POINTS = 'POINTS' + INCHES = "INCHES" + PIXELS = "PIXELS" + POINTS = "POINTS" class DocumentRectangle(Rectangle): @@ -103,5 +112,6 @@ class DocumentRectangle(Rectangle): page (int): Page number of the document unit (RectangleUnits): Units of the rectangle """ + page: int unit: RectangleUnit diff --git a/libs/labelbox/src/labelbox/data/annotation_types/label.py b/libs/labelbox/src/labelbox/data/annotation_types/label.py index 973e9260f..c21a0ef8c 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/label.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/label.py @@ -3,14 +3,28 @@ import warnings import labelbox -from labelbox.data.annotation_types.data.generic_data_row_data import GenericDataRowData +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) from labelbox.data.annotation_types.data.tiled_image import TiledImageData from labelbox.schema import ontology from .annotation import ClassificationAnnotation, ObjectAnnotation from .relationship import RelationshipAnnotation from .llm_prompt_response.prompt import PromptClassificationAnnotation from .classification import ClassificationAnswer -from .data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, TextData, VideoData, LlmPromptCreationData, LlmPromptResponseCreationData, LlmResponseCreationData +from .data import ( + AudioData, + ConversationData, + DicomData, + DocumentData, + HTMLData, + ImageData, + TextData, + VideoData, + LlmPromptCreationData, + LlmPromptResponseCreationData, + LlmResponseCreationData, +) from .geometry import Mask from .metrics import ScalarMetric, ConfusionMatrixMetric from .types import Cuid @@ -20,10 +34,21 @@ from ..ontology import get_feature_schema_lookup from pydantic import BaseModel, field_validator, model_serializer -DataType = Union[VideoData, ImageData, TextData, TiledImageData, AudioData, - ConversationData, DicomData, DocumentData, HTMLData, - LlmPromptCreationData, LlmPromptResponseCreationData, - LlmResponseCreationData, GenericDataRowData] +DataType = Union[ + VideoData, + ImageData, + TextData, + TiledImageData, + AudioData, + ConversationData, + DicomData, + DocumentData, + HTMLData, + LlmPromptCreationData, + LlmPromptResponseCreationData, + LlmResponseCreationData, + GenericDataRowData, +] class Label(BaseModel): @@ -41,17 +66,26 @@ class Label(BaseModel): Args: uid: Optional Label Id in Labelbox - data: Data of Label, Image, Video, Text or dict with a single key uid | global_key | external_id. + data: Data of Label, Image, Video, Text or dict with a single key uid | global_key | external_id. Note use of classes as data is deprecated. Use GenericDataRowData or dict with a single key instead. 
annotations: List of Annotations in the label extra: additional context """ + uid: Optional[Cuid] = None data: DataType - annotations: List[Union[ClassificationAnnotation, ObjectAnnotation, - VideoMaskAnnotation, ScalarMetric, - ConfusionMatrixMetric, RelationshipAnnotation, - PromptClassificationAnnotation, MessageEvaluationTaskAnnotation]] = [] + annotations: List[ + Union[ + ClassificationAnnotation, + ObjectAnnotation, + VideoMaskAnnotation, + ScalarMetric, + ConfusionMatrixMetric, + RelationshipAnnotation, + PromptClassificationAnnotation, + MessageEvaluationTaskAnnotation, + ] + ] = [] extra: Dict[str, Any] = {} is_benchmark_reference: Optional[bool] = False @@ -64,7 +98,8 @@ def validate_data(cls, data): else: warnings.warn( f"Using {type(data).__name__} class for label.data is deprecated. " - "Use a dict or an instance of GenericDataRowData instead.") + "Use a dict or an instance of GenericDataRowData instead." + ) return data def object_annotations(self) -> List[ObjectAnnotation]: @@ -75,18 +110,20 @@ def classification_annotations(self) -> List[ClassificationAnnotation]: def _get_annotations_by_type(self, annotation_type): return [ - annot for annot in self.annotations + annot + for annot in self.annotations if isinstance(annot, annotation_type) ] def frame_annotations( - self + self, ) -> Dict[str, Union[VideoObjectAnnotation, VideoClassificationAnnotation]]: frame_dict = defaultdict(list) for annotation in self.annotations: if isinstance( - annotation, - (VideoObjectAnnotation, VideoClassificationAnnotation)): + annotation, + (VideoObjectAnnotation, VideoClassificationAnnotation), + ): frame_dict[annotation.frame].append(annotation) return frame_dict @@ -128,8 +165,9 @@ def add_url_to_masks(self, signer) -> "Label": mask.create_url(signer) return self - def create_data_row(self, dataset: "labelbox.Dataset", - signer: Callable[[bytes], str]) -> "Label": + def create_data_row( + self, dataset: "labelbox.Dataset", signer: Callable[[bytes], str] + ) -> "Label": """ Creates a data row and adds to the given dataset. Updates the label's data object to have the same external_id and uid as the data row. @@ -140,9 +178,9 @@ def create_data_row(self, dataset: "labelbox.Dataset", Returns: Label with updated references to new data row """ - args = {'row_data': self.data.create_url(signer)} + args = {"row_data": self.data.create_url(signer)} if self.data.external_id is not None: - args.update({'external_id': self.data.external_id}) + args.update({"external_id": self.data.external_id}) if self.data.uid is None: data_row = dataset.create_data_row(**args) @@ -151,7 +189,8 @@ def create_data_row(self, dataset: "labelbox.Dataset", return self def assign_feature_schema_ids( - self, ontology_builder: ontology.OntologyBuilder) -> "Label": + self, ontology_builder: ontology.OntologyBuilder + ) -> "Label": """ Adds schema ids to all FeatureSchema objects in the Labels. @@ -162,11 +201,14 @@ def assign_feature_schema_ids( Note: You can now import annotations using names directly without having to lookup schema_ids """ - warnings.warn("This method is deprecated and will be " - "removed in a future release. Feature schema ids" - " are no longer required for importing.") + warnings.warn( + "This method is deprecated and will be " + "removed in a future release. Feature schema ids" + " are no longer required for importing." 
+ ) tool_lookup, classification_lookup = get_feature_schema_lookup( - ontology_builder) + ontology_builder + ) for annotation in self.annotations: if isinstance(annotation, ClassificationAnnotation): self._assign_or_raise(annotation, classification_lookup) @@ -178,7 +220,8 @@ def assign_feature_schema_ids( self._assign_option(classification, classification_lookup) else: raise TypeError( - f"Unexpected type found for annotation. {type(annotation)}") + f"Unexpected type found for annotation. {type(annotation)}" + ) return self def _assign_or_raise(self, annotation, lookup: Dict[str, str]) -> None: @@ -187,12 +230,15 @@ def _assign_or_raise(self, annotation, lookup: Dict[str, str]) -> None: feature_schema_id = lookup.get(annotation.name) if feature_schema_id is None: - raise ValueError(f"No tool matches name {annotation.name}. " - f"Must be one of {list(lookup.keys())}.") + raise ValueError( + f"No tool matches name {annotation.name}. " + f"Must be one of {list(lookup.keys())}." + ) annotation.feature_schema_id = feature_schema_id - def _assign_option(self, classification: ClassificationAnnotation, - lookup: Dict[str, str]) -> None: + def _assign_option( + self, classification: ClassificationAnnotation, lookup: Dict[str, str] + ) -> None: if isinstance(classification.value.answer, str): pass elif isinstance(classification.value.answer, ClassificationAnswer): @@ -207,10 +253,14 @@ def _assign_option(self, classification: ClassificationAnnotation, @field_validator("annotations", mode="before") def validate_union(cls, value): - supported = tuple([ - field - for field in get_args(get_args(cls.model_fields['annotations'].annotation)[0]) - ]) + supported = tuple( + [ + field + for field in get_args( + get_args(cls.model_fields["annotations"].annotation)[0] + ) + ] + ) if not isinstance(value, list): raise TypeError(f"Annotations must be a list. 
Found {type(value)}") prompt_count = 0 @@ -224,5 +274,6 @@ def validate_union(cls, value): prompt_count += 1 if prompt_count > 1: raise TypeError( - f"Only one prompt annotation is allowed per label") + f"Only one prompt annotation is allowed per label" + ) return value diff --git a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py index 7c0b63abc..4f4c0ee0e 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py @@ -1,2 +1,2 @@ from .prompt import PromptText -from .prompt import PromptClassificationAnnotation \ No newline at end of file +from .prompt import PromptClassificationAnnotation diff --git a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py index 98c0e7a69..b5a7e4fe5 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py @@ -4,7 +4,7 @@ class PromptText(ConfidenceMixin, CustomMetricsMixin, BaseModel): - """ Prompt text for LLM data generation + """Prompt text for LLM data generation >>> PromptText(answer = "some text answer", >>> confidence = 0.5, @@ -14,11 +14,13 @@ class PromptText(ConfidenceMixin, CustomMetricsMixin, BaseModel): >>> "value": 0.1 >>> }]) """ + answer: str -class PromptClassificationAnnotation(BaseAnnotation, ConfidenceMixin, - CustomMetricsMixin): +class PromptClassificationAnnotation( + BaseAnnotation, ConfidenceMixin, CustomMetricsMixin +): """Prompt annotation (non localized) >>> PromptClassificationAnnotation( @@ -30,6 +32,6 @@ class PromptClassificationAnnotation(BaseAnnotation, ConfidenceMixin, name (Optional[str]) feature_schema_id (Optional[Cuid]) value (Union[Text]) - """ + """ - value: PromptText \ No newline at end of file + value: PromptText diff --git a/libs/labelbox/src/labelbox/data/annotation_types/metrics/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/metrics/__init__.py index 2c7e45178..37750dd1f 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/metrics/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/metrics/__init__.py @@ -1,2 +1,6 @@ from .scalar import ScalarMetric, ScalarMetricAggregation, ScalarMetricValue -from .confusion_matrix import ConfusionMatrixMetric, ConfusionMatrixAggregation, ConfusionMatrixMetricValue +from .confusion_matrix import ( + ConfusionMatrixMetric, + ConfusionMatrixAggregation, + ConfusionMatrixMetricValue, +) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/metrics/base.py b/libs/labelbox/src/labelbox/data/annotation_types/metrics/base.py index 7c0636f48..0a4773a41 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/metrics/base.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/metrics/base.py @@ -1,7 +1,13 @@ from abc import ABC from typing import Dict, Optional, Any, Union -from pydantic import confloat, BaseModel, model_serializer, field_validator, error_wrappers +from pydantic import ( + confloat, + BaseModel, + model_serializer, + field_validator, + error_wrappers, +) from pydantic_core import ValidationError, InitErrorDetails ConfidenceValue = confloat(ge=0, le=1) @@ -21,15 +27,15 @@ def serialize_model(self, handler): res = handler(self) return {k: v for k, v in 
res.items() if v is not None} - - @field_validator('value') + @field_validator("value") def validate_value(cls, value): if isinstance(value, Dict): - if not (MIN_CONFIDENCE_SCORES <= len(value) <= - MAX_CONFIDENCE_SCORES): + if not ( + MIN_CONFIDENCE_SCORES <= len(value) <= MAX_CONFIDENCE_SCORES + ): raise ValueError( - f"Number of confidence scores must be greater than\n \ + f"Number of confidence scores must be greater than\n \ or equal to {MIN_CONFIDENCE_SCORES} and less than\n \ or equal to {MAX_CONFIDENCE_SCORES}. Found {len(value)}" - ) + ) return value diff --git a/libs/labelbox/src/labelbox/data/annotation_types/metrics/confusion_matrix.py b/libs/labelbox/src/labelbox/data/annotation_types/metrics/confusion_matrix.py index 4a346b8f4..30e2d2ed4 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/metrics/confusion_matrix.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/metrics/confusion_matrix.py @@ -9,8 +9,9 @@ Count = conint(ge=0, le=1e10) ConfusionMatrixMetricValue = Tuple[Count, Count, Count, Count] -ConfusionMatrixMetricConfidenceValue = Dict[ConfidenceValue, - ConfusionMatrixMetricValue] +ConfusionMatrixMetricConfidenceValue = Dict[ + ConfidenceValue, ConfusionMatrixMetricValue +] class ConfusionMatrixAggregation(Enum): @@ -18,7 +19,7 @@ class ConfusionMatrixAggregation(Enum): class ConfusionMatrixMetric(BaseMetric): - """ Class representing confusion matrix metrics. + """Class representing confusion matrix metrics. In the editor, this provides precision, recall, and f-scores. This should be used over multiple scalar metrics so that aggregations are accurate. @@ -28,7 +29,11 @@ class ConfusionMatrixMetric(BaseMetric): aggregation cannot be adjusted for confusion matrix metrics. """ + metric_name: str - value: Union[ConfusionMatrixMetricValue, - ConfusionMatrixMetricConfidenceValue] - aggregation: Optional[ConfusionMatrixAggregation] = ConfusionMatrixAggregation.CONFUSION_MATRIX + value: Union[ + ConfusionMatrixMetricValue, ConfusionMatrixMetricConfidenceValue + ] + aggregation: Optional[ConfusionMatrixAggregation] = ( + ConfusionMatrixAggregation.CONFUSION_MATRIX + ) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py b/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py index 560d6dcef..13d0e9748 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py @@ -18,30 +18,41 @@ class ScalarMetricAggregation(Enum): SUM = "SUM" -RESERVED_METRIC_NAMES = ('true_positive_count', 'false_positive_count', - 'true_negative_count', 'false_negative_count', - 'precision', 'recall', 'f1', 'iou') +RESERVED_METRIC_NAMES = ( + "true_positive_count", + "false_positive_count", + "true_negative_count", + "false_negative_count", + "precision", + "recall", + "f1", + "iou", +) class ScalarMetric(BaseMetric): - """ Class representing scalar metrics + """Class representing scalar metrics For backwards compatibility, metric_name is optional. The metric_name will be set to a default name in the editor if it is not set. This is not recommended and support for empty metric_name fields will be removed. aggregation will be ignored without providing a metric name. 
""" + metric_name: Optional[str] = None value: Union[ScalarMetricValue, ScalarMetricConfidenceValue] - aggregation: Optional[ - ScalarMetricAggregation] = ScalarMetricAggregation.ARITHMETIC_MEAN + aggregation: Optional[ScalarMetricAggregation] = ( + ScalarMetricAggregation.ARITHMETIC_MEAN + ) - @field_validator('metric_name') + @field_validator("metric_name") def validate_metric_name(cls, name: Union[str, None]): if name is None: return None clean_name = name.lower().strip() if clean_name in RESERVED_METRIC_NAMES: - raise ValueError(f"`{clean_name}` is a reserved metric name. " - "Please provide another value for `metric_name`.") + raise ValueError( + f"`{clean_name}` is a reserved metric name. " + "Please provide another value for `metric_name`." + ) return name diff --git a/libs/labelbox/src/labelbox/data/annotation_types/mmc.py b/libs/labelbox/src/labelbox/data/annotation_types/mmc.py index d3ab763cb..e2ed74d41 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/mmc.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/mmc.py @@ -10,7 +10,7 @@ class MessageInfo(_CamelCaseMixin): message_id: str model_config_name: str - + model_config = ConfigDict(protected_namespaces=()) @@ -21,7 +21,7 @@ class OrderedMessageInfo(MessageInfo): class _BaseMessageEvaluationTask(_CamelCaseMixin, ABC): format: ClassVar[str] parent_message_id: str - + model_config = ConfigDict(protected_namespaces=()) @@ -48,5 +48,8 @@ def _validate_ranked_messages(cls, v: List[OrderedMessageInfo]): class MessageEvaluationTaskAnnotation(BaseAnnotation): - value: Union[MessageSingleSelectionTask, MessageMultiSelectionTask, - MessageRankingTask] + value: Union[ + MessageSingleSelectionTask, + MessageMultiSelectionTask, + MessageRankingTask, + ] diff --git a/libs/labelbox/src/labelbox/data/annotation_types/ner/conversation_entity.py b/libs/labelbox/src/labelbox/data/annotation_types/ner/conversation_entity.py index 53b9059b9..e8bd49b56 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/ner/conversation_entity.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/ner/conversation_entity.py @@ -3,5 +3,6 @@ class ConversationEntity(TextEntity, _CamelCaseMixin): - """ Represents a text entity """ - message_id: str \ No newline at end of file + """Represents a text entity""" + + message_id: str diff --git a/libs/labelbox/src/labelbox/data/annotation_types/ner/document_entity.py b/libs/labelbox/src/labelbox/data/annotation_types/ner/document_entity.py index c2acecd7c..6a5abec23 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/ner/document_entity.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/ner/document_entity.py @@ -17,5 +17,6 @@ def validate_page(cls, v): class DocumentEntity(_CamelCaseMixin, BaseModel): - """ Represents a text entity """ + """Represents a text entity""" + text_selections: List[DocumentTextSelection] diff --git a/libs/labelbox/src/labelbox/data/annotation_types/ner/text_entity.py b/libs/labelbox/src/labelbox/data/annotation_types/ner/text_entity.py index 60764f759..ece341434 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/ner/text_entity.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/ner/text_entity.py @@ -4,16 +4,17 @@ class TextEntity(BaseModel): - """ Represents a text entity """ + """Represents a text entity""" + start: int end: int extra: Dict[str, Any] = {} @model_validator(mode="after") def validate_start_end(self, values): - if hasattr(self, 'start') and hasattr(self, 'end'): - if (isinstance(self.start, int) and - self.start > 
self.end): + if hasattr(self, "start") and hasattr(self, "end"): + if isinstance(self.start, int) and self.start > self.end: raise ValueError( - "Location end must be greater or equal to start") + "Location end must be greater or equal to start" + ) return self diff --git a/libs/labelbox/src/labelbox/data/annotation_types/relationship.py b/libs/labelbox/src/labelbox/data/annotation_types/relationship.py index 27a833830..b65f21d16 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/relationship.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/relationship.py @@ -1,10 +1,12 @@ from pydantic import BaseModel from enum import Enum -from labelbox.data.annotation_types.annotation import BaseAnnotation, ObjectAnnotation +from labelbox.data.annotation_types.annotation import ( + BaseAnnotation, + ObjectAnnotation, +) class Relationship(BaseModel): - class Type(Enum): UNIDIRECTIONAL = "unidirectional" BIDIRECTIONAL = "bidirectional" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/types.py b/libs/labelbox/src/labelbox/data/annotation_types/types.py index b26789aae..0a9793f8f 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/types.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/types.py @@ -9,12 +9,11 @@ Cuid = Annotated[str, StringConstraints(min_length=25, max_length=25)] -DType = TypeVar('DType') -DShape = TypeVar('DShape') +DType = TypeVar("DType") +DShape = TypeVar("DShape") class _TypedArray(np.ndarray, Generic[DType, DShape]): - @classmethod def __get_validators__(cls): yield cls.validate @@ -26,15 +25,21 @@ def validate(cls, val, field: Field): return val -if version.parse(np.__version__) >= version.parse('1.25.0'): +if version.parse(np.__version__) >= version.parse("1.25.0"): from typing import GenericAlias + TypedArray = GenericAlias(_TypedArray, (Any, DType)) -elif version.parse(np.__version__) >= version.parse('1.23.0'): +elif version.parse(np.__version__) >= version.parse("1.23.0"): from numpy._typing import _GenericAlias + TypedArray = _GenericAlias(_TypedArray, (Any, DType)) -elif version.parse('1.22.0') <= version.parse( - np.__version__) < version.parse('1.23.0'): +elif ( + version.parse("1.22.0") + <= version.parse(np.__version__) + < version.parse("1.23.0") +): from numpy.typing import _GenericAlias + TypedArray = _GenericAlias(_TypedArray, (Any, DType)) else: TypedArray = _TypedArray[Any, DType] diff --git a/libs/labelbox/src/labelbox/data/annotation_types/video.py b/libs/labelbox/src/labelbox/data/annotation_types/video.py index 79a14ec2d..cfebd7a1f 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/video.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/video.py @@ -1,13 +1,30 @@ from enum import Enum from typing import List, Optional, Tuple -from labelbox.data.annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation - -from labelbox.data.annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation +from labelbox.data.annotation_types.annotation import ( + ClassificationAnnotation, + ObjectAnnotation, +) + +from labelbox.data.annotation_types.annotation import ( + ClassificationAnnotation, + ObjectAnnotation, +) from labelbox.data.annotation_types.feature import FeatureSchema -from labelbox.data.mixins import ConfidenceNotSupportedMixin, CustomMetricsNotSupportedMixin +from labelbox.data.mixins import ( + ConfidenceNotSupportedMixin, + CustomMetricsNotSupportedMixin, +) from labelbox.utils import _CamelCaseMixin, is_valid_uri -from pydantic import model_validator, 
BaseModel, field_validator, model_serializer, Field, ConfigDict, AliasChoices +from pydantic import ( + model_validator, + BaseModel, + field_validator, + model_serializer, + Field, + ConfigDict, + AliasChoices, +) class VideoClassificationAnnotation(ClassificationAnnotation): @@ -20,12 +37,16 @@ class VideoClassificationAnnotation(ClassificationAnnotation): segment_id (Optional[Int]): Index of video segment this annotation belongs to extra (Dict[str, Any]) """ + frame: int segment_index: Optional[int] = None -class VideoObjectAnnotation(ObjectAnnotation, ConfidenceNotSupportedMixin, - CustomMetricsNotSupportedMixin): +class VideoObjectAnnotation( + ObjectAnnotation, + ConfidenceNotSupportedMixin, + CustomMetricsNotSupportedMixin, +): """Video object annotation >>> VideoObjectAnnotation( >>> keyframe=True, @@ -46,14 +67,15 @@ class VideoObjectAnnotation(ObjectAnnotation, ConfidenceNotSupportedMixin, classifications (List[ClassificationAnnotation]) = [] extra (Dict[str, Any]) """ + frame: int keyframe: bool segment_index: Optional[int] = None class GroupKey(Enum): - """Group key for DICOM annotations - """ + """Group key for DICOM annotations""" + AXIAL = "axial" SAGITTAL = "sagittal" CORONAL = "coronal" @@ -84,14 +106,19 @@ class DICOMObjectAnnotation(VideoObjectAnnotation): classifications (List[ClassificationAnnotation]) = [] extra (Dict[str, Any]) """ + group_key: GroupKey class MaskFrame(_CamelCaseMixin, BaseModel): index: int - instance_uri: Optional[str] = Field(default=None, validation_alias=AliasChoices("instanceURI", "instanceUri"), serialization_alias="instanceURI") + instance_uri: Optional[str] = Field( + default=None, + validation_alias=AliasChoices("instanceURI", "instanceUri"), + serialization_alias="instanceURI", + ) im_bytes: Optional[bytes] = None - + model_config = ConfigDict(populate_by_name=True) @model_validator(mode="after") @@ -110,43 +137,49 @@ def validate_uri(cls, v): class MaskInstance(_CamelCaseMixin, FeatureSchema): - color_rgb: Tuple[int, int, int] = Field(validation_alias=AliasChoices("colorRGB", "colorRgb"), serialization_alias="colorRGB") + color_rgb: Tuple[int, int, int] = Field( + validation_alias=AliasChoices("colorRGB", "colorRgb"), + serialization_alias="colorRGB", + ) name: str model_config = ConfigDict(populate_by_name=True) + class VideoMaskAnnotation(BaseModel): """Video mask annotation - >>> VideoMaskAnnotation( - >>> frames=[ - >>> MaskFrame(index=1, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), - >>> MaskFrame(index=5, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), - >>> ], - >>> instances=[ - >>> MaskInstance(color_rgb=(0, 0, 255), name="mask1"), - >>> MaskInstance(color_rgb=(0, 255, 0), name="mask2"), - >>> MaskInstance(color_rgb=(255, 0, 0), name="mask3") - >>> ] - >>> ) + >>> VideoMaskAnnotation( + >>> frames=[ + >>> MaskFrame(index=1, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), + >>> MaskFrame(index=5, 
instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), + >>> ], + >>> instances=[ + >>> MaskInstance(color_rgb=(0, 0, 255), name="mask1"), + >>> MaskInstance(color_rgb=(0, 255, 0), name="mask2"), + >>> MaskInstance(color_rgb=(255, 0, 0), name="mask3") + >>> ] + >>> ) """ + frames: List[MaskFrame] instances: List[MaskInstance] class DICOMMaskAnnotation(VideoMaskAnnotation): """DICOM mask annotation - >>> DICOMMaskAnnotation( - >>> name="dicom_mask", - >>> group_key=GroupKey.AXIAL, - >>> frames=[ - >>> MaskFrame(index=1, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), - >>> MaskFrame(index=5, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), - >>> ], - >>> instances=[ - >>> MaskInstance(color_rgb=(0, 0, 255), name="mask1"), - >>> MaskInstance(color_rgb=(0, 255, 0), name="mask2"), - >>> MaskInstance(color_rgb=(255, 0, 0), name="mask3") - >>> ] - >>> ) + >>> DICOMMaskAnnotation( + >>> name="dicom_mask", + >>> group_key=GroupKey.AXIAL, + >>> frames=[ + >>> MaskFrame(index=1, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), + >>> MaskFrame(index=5, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), + >>> ], + >>> instances=[ + >>> MaskInstance(color_rgb=(0, 0, 255), name="mask1"), + >>> MaskInstance(color_rgb=(0, 255, 0), name="mask2"), + >>> MaskInstance(color_rgb=(255, 0, 0), name="mask3") + >>> ] + >>> ) """ + group_key: GroupKey diff --git a/libs/labelbox/src/labelbox/data/generator.py b/libs/labelbox/src/labelbox/data/generator.py index 891dc1315..8270a6715 100644 --- a/libs/labelbox/src/labelbox/data/generator.py +++ b/libs/labelbox/src/labelbox/data/generator.py @@ -13,9 +13,7 @@ class ThreadSafeGen: """ def __init__(self, iterable: Iterable[Any]): - """ - - """ + """ """ self.iterable = iterable self.lock = threading.Lock() @@ -70,7 +68,8 @@ def fill_queue(self): self.queue.put(value) except Exception as e: self.queue.put( - ValueError(f"Unexpected exception while filling queue: {e}")) + ValueError(f"Unexpected exception while filling queue: {e}") + ) finally: self.queue.put(None) diff --git a/libs/labelbox/src/labelbox/data/metrics/__init__.py b/libs/labelbox/src/labelbox/data/metrics/__init__.py index f99fc85a8..7085b772e 100644 --- a/libs/labelbox/src/labelbox/data/metrics/__init__.py +++ b/libs/labelbox/src/labelbox/data/metrics/__init__.py @@ -1,2 +1,5 @@ -from .confusion_matrix import confusion_matrix_metric, feature_confusion_matrix_metric +from .confusion_matrix import ( + confusion_matrix_metric, + feature_confusion_matrix_metric, +) from .iou import miou_metric, feature_miou_metric diff --git a/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py b/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py index 1b1fc801b..938e17f65 100644 --- 
a/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py +++ b/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py @@ -2,20 +2,37 @@ import numpy as np -from ..iou.calculation import _get_mask_pairs, _get_vector_pairs, _get_ner_pairs, miou -from ...annotation_types import (ObjectAnnotation, ClassificationAnnotation, - Mask, Geometry, Checklist, Radio, TextEntity, - ScalarMetricValue, ConfusionMatrixMetricValue) -from ..group import (get_feature_pairs, get_identifying_key, has_no_annotations, - has_no_matching_annotations) - - -def confusion_matrix(ground_truths: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - include_subclasses: bool, - iou: float) -> ConfusionMatrixMetricValue: +from ..iou.calculation import ( + _get_mask_pairs, + _get_vector_pairs, + _get_ner_pairs, + miou, +) +from ...annotation_types import ( + ObjectAnnotation, + ClassificationAnnotation, + Mask, + Geometry, + Checklist, + Radio, + TextEntity, + ScalarMetricValue, + ConfusionMatrixMetricValue, +) +from ..group import ( + get_feature_pairs, + get_identifying_key, + has_no_annotations, + has_no_matching_annotations, +) + + +def confusion_matrix( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + include_subclasses: bool, + iou: float, +) -> ConfusionMatrixMetricValue: """ Computes the confusion matrix for an arbitrary set of ground truth and predicted annotations. It first computes the confusion matrix for each metric and then sums across all classes @@ -33,8 +50,9 @@ def confusion_matrix(ground_truths: List[Union[ObjectAnnotation, annotation_pairs = get_feature_pairs(ground_truths, predictions) conf_matrix = [ - feature_confusion_matrix(annotation_pair[0], annotation_pair[1], - include_subclasses, iou) + feature_confusion_matrix( + annotation_pair[0], annotation_pair[1], include_subclasses, iou + ) for annotation_pair in annotation_pairs.values() ] matrices = [matrix for matrix in conf_matrix if matrix is not None] @@ -42,10 +60,11 @@ def confusion_matrix(ground_truths: List[Union[ObjectAnnotation, def feature_confusion_matrix( - ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], - include_subclasses: bool, - iou: float) -> Optional[ConfusionMatrixMetricValue]: + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + include_subclasses: bool, + iou: float, +) -> Optional[ConfusionMatrixMetricValue]: """ Computes confusion matrix for all features of the same class. 
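# Illustrative sketch, not part of this patch: the helpers in this module report
# confusion-matrix counts as a plain list [tps, fps, tns, fns], and confusion_matrix()
# above collects one such list per feature class before aggregating. Element-wise
# summation of those per-class lists could look roughly like the toy example below;
# the counts are invented purely for demonstration.
import numpy as np

per_feature_matrices = [
    [3, 1, 0, 2],  # hypothetical [tps, fps, tns, fns] for one feature class
    [5, 0, 0, 1],  # hypothetical counts for another feature class
]
aggregated = np.sum(per_feature_matrices, axis=0).tolist()
print(aggregated)  # [8, 1, 0, 3] -> combined [tps, fps, tns, fns]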
@@ -63,24 +82,28 @@ def feature_confusion_matrix( elif has_no_annotations(ground_truths, predictions): return None elif isinstance(predictions[0].value, Mask): - return mask_confusion_matrix(ground_truths, predictions, - include_subclasses, iou) + return mask_confusion_matrix( + ground_truths, predictions, include_subclasses, iou + ) elif isinstance(predictions[0].value, Geometry): - return vector_confusion_matrix(ground_truths, predictions, - include_subclasses, iou) + return vector_confusion_matrix( + ground_truths, predictions, include_subclasses, iou + ) elif isinstance(predictions[0].value, TextEntity): - return ner_confusion_matrix(ground_truths, predictions, - include_subclasses, iou) + return ner_confusion_matrix( + ground_truths, predictions, include_subclasses, iou + ) elif isinstance(predictions[0], ClassificationAnnotation): return classification_confusion_matrix(ground_truths, predictions) else: raise ValueError( - f"Unexpected annotation found. Found {type(predictions[0].value)}") + f"Unexpected annotation found. Found {type(predictions[0].value)}" + ) def classification_confusion_matrix( - ground_truths: List[ClassificationAnnotation], - predictions: List[ClassificationAnnotation] + ground_truths: List[ClassificationAnnotation], + predictions: List[ClassificationAnnotation], ) -> ConfusionMatrixMetricValue: """ Computes iou score for all features with the same feature schema id. @@ -97,9 +120,11 @@ def classification_confusion_matrix( if has_no_matching_annotations(ground_truths, predictions): return [0, len(predictions), 0, len(ground_truths)] - elif has_no_annotations( - ground_truths, - predictions) or len(predictions) > 1 or len(ground_truths) > 1: + elif ( + has_no_annotations(ground_truths, predictions) + or len(predictions) > 1 + or len(ground_truths) > 1 + ): # Note that we could return [0,0,0,0] but that will bloat the imports for no reason return None @@ -108,7 +133,8 @@ def classification_confusion_matrix( if type(prediction) != type(ground_truth): raise TypeError( "Classification features must be the same type to compute agreement. " - f"Found `{type(prediction)}` and `{type(ground_truth)}`") + f"Found `{type(prediction)}` and `{type(ground_truth)}`" + ) if isinstance(prediction.value, Radio): return radio_confusion_matrix(ground_truth.value, prediction.value) @@ -120,11 +146,13 @@ def classification_confusion_matrix( ) -def vector_confusion_matrix(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation], - include_subclasses: bool, - iou: float, - buffer=70.) -> Optional[ConfusionMatrixMetricValue]: +def vector_confusion_matrix( + ground_truths: List[ObjectAnnotation], + predictions: List[ObjectAnnotation], + include_subclasses: bool, + iou: float, + buffer=70.0, +) -> Optional[ConfusionMatrixMetricValue]: """ Computes confusion matrix for any vector class (point, polygon, line, rectangle). Ground truths and predictions should all belong to the same class. 
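# Illustrative sketch, not part of this patch: object_pair_confusion_matrix() below
# walks (ground_truth, prediction, iou) triplets and counts a true positive only when
# the IoU clears the threshold and neither side has been matched yet; leftover
# annotations become false positives or false negatives. The toy triplets here are
# invented stand-ins for real annotation pairs.
pairs = [("gt1", "pred1", 0.83), ("gt1", "pred2", 0.41), ("gt2", "pred2", 0.77)]
iou_threshold = 0.5
matched_gt, matched_pred, tps = set(), set(), 0
for gt, pred, agreement in pairs:
    if (
        agreement > iou_threshold
        and gt not in matched_gt
        and pred not in matched_pred
    ):
        matched_gt.add(gt)
        matched_pred.add(pred)
        tps += 1
fps = len({p for _, p, _ in pairs}) - len(matched_pred)  # unmatched predictions
fns = len({g for g, _, _ in pairs}) - len(matched_gt)  # unmatched ground truths
print([tps, fps, 0, fns])  # [2, 0, 0, 0]; no true negatives in this toy case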
@@ -149,11 +177,11 @@ def vector_confusion_matrix(ground_truths: List[ObjectAnnotation], return object_pair_confusion_matrix(pairs, include_subclasses, iou) -def object_pair_confusion_matrix(pairs: List[Tuple[ObjectAnnotation, - ObjectAnnotation, - ScalarMetricValue]], - include_subclasses: bool, - iou: float) -> ConfusionMatrixMetricValue: +def object_pair_confusion_matrix( + pairs: List[Tuple[ObjectAnnotation, ObjectAnnotation, ScalarMetricValue]], + include_subclasses: bool, + iou: float, +) -> ConfusionMatrixMetricValue: """ Computes the confusion matrix for a list of object annotation pairs. Performs greedy matching of pairs. @@ -177,14 +205,22 @@ def object_pair_confusion_matrix(pairs: List[Tuple[ObjectAnnotation, prediction_ids.add(prediction_id) ground_truth_ids.add(ground_truth_id) - if agreement > iou and \ - prediction_id not in matched_predictions and \ - ground_truth_id not in matched_ground_truths: - if include_subclasses and (ground_truth.classifications or - prediction.classifications): - if miou(ground_truth.classifications, + if ( + agreement > iou + and prediction_id not in matched_predictions + and ground_truth_id not in matched_ground_truths + ): + if include_subclasses and ( + ground_truth.classifications or prediction.classifications + ): + if ( + miou( + ground_truth.classifications, prediction.classifications, - include_subclasses=False) < 1.: + include_subclasses=False, + ) + < 1.0 + ): # Incorrect if the subclasses don't 100% agree then there is no match continue matched_predictions.add(prediction_id) @@ -198,8 +234,9 @@ def object_pair_confusion_matrix(pairs: List[Tuple[ObjectAnnotation, return [tps, fps, tns, fns] -def radio_confusion_matrix(ground_truth: Radio, - prediction: Radio) -> ConfusionMatrixMetricValue: +def radio_confusion_matrix( + ground_truth: Radio, prediction: Radio +) -> ConfusionMatrixMetricValue: """ Calculates confusion between ground truth and predicted radio values @@ -220,8 +257,8 @@ def radio_confusion_matrix(ground_truth: Radio, def checklist_confusion_matrix( - ground_truth: Checklist, - prediction: Checklist) -> ConfusionMatrixMetricValue: + ground_truth: Checklist, prediction: Checklist +) -> ConfusionMatrixMetricValue: """ Calculates agreement between ground truth and predicted checklist items: @@ -246,10 +283,12 @@ def checklist_confusion_matrix( return [tps, fps, 0, fns] -def mask_confusion_matrix(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation], - include_subclasses: bool, - iou: float) -> Optional[ScalarMetricValue]: +def mask_confusion_matrix( + ground_truths: List[ObjectAnnotation], + predictions: List[ObjectAnnotation], + include_subclasses: bool, + iou: float, +) -> Optional[ScalarMetricValue]: """ Computes confusion matrix metric for two masks @@ -269,15 +308,17 @@ def mask_confusion_matrix(ground_truths: List[ObjectAnnotation], return None pairs = _get_mask_pairs(ground_truths, predictions) - return object_pair_confusion_matrix(pairs, - include_subclasses=include_subclasses, - iou=iou) + return object_pair_confusion_matrix( + pairs, include_subclasses=include_subclasses, iou=iou + ) -def ner_confusion_matrix(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation], - include_subclasses: bool, - iou: float) -> Optional[ConfusionMatrixMetricValue]: +def ner_confusion_matrix( + ground_truths: List[ObjectAnnotation], + predictions: List[ObjectAnnotation], + include_subclasses: bool, + iou: float, +) -> Optional[ConfusionMatrixMetricValue]: """Computes confusion matrix metric 
between two lists of TextEntity objects Args: diff --git a/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/confusion_matrix.py b/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/confusion_matrix.py index 19caab426..6d817b105 100644 --- a/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/confusion_matrix.py +++ b/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/confusion_matrix.py @@ -3,8 +3,11 @@ from labelbox.data.annotation_types import feature from labelbox.data.annotation_types.metrics import ConfusionMatrixMetric from typing import List, Optional, Union -from ...annotation_types import (Label, ObjectAnnotation, - ClassificationAnnotation) +from ...annotation_types import ( + Label, + ObjectAnnotation, + ClassificationAnnotation, +) from ..group import get_feature_pairs from .calculation import confusion_matrix @@ -12,12 +15,12 @@ import numpy as np -def confusion_matrix_metric(ground_truths: List[Union[ - ObjectAnnotation, ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - include_subclasses=False, - iou=0.5) -> List[ConfusionMatrixMetric]: +def confusion_matrix_metric( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + include_subclasses=False, + iou=0.5, +) -> List[ConfusionMatrixMetric]: """ Computes confusion matrix metrics between two sets of annotations. These annotations should relate to the same data (image/video). @@ -31,11 +34,12 @@ def confusion_matrix_metric(ground_truths: List[Union[ Returns: Returns a list of ConfusionMatrixMetrics. Will be empty if there were no predictions and labels. Otherwise a single metric will be returned. """ - if not (0. < iou < 1.): + if not (0.0 < iou < 1.0): raise ValueError("iou must be between 0 and 1") - value = confusion_matrix(ground_truths, predictions, include_subclasses, - iou) + value = confusion_matrix( + ground_truths, predictions, include_subclasses, iou + ) # If both gt and preds are empty there is no metric if value is None: return [] @@ -68,39 +72,45 @@ def feature_confusion_matrix_metric( annotation_pairs = get_feature_pairs(ground_truths, predictions) metrics = [] for key in annotation_pairs: - value = feature_confusion_matrix(annotation_pairs[key][0], - annotation_pairs[key][1], - include_subclasses, iou) + value = feature_confusion_matrix( + annotation_pairs[key][0], + annotation_pairs[key][1], + include_subclasses, + iou, + ) if value is None: continue - metric_name = _get_metric_name(annotation_pairs[key][0], - annotation_pairs[key][1], iou) + metric_name = _get_metric_name( + annotation_pairs[key][0], annotation_pairs[key][1], iou + ) metrics.append( - ConfusionMatrixMetric(metric_name=metric_name, - feature_name=key, - value=value)) + ConfusionMatrixMetric( + metric_name=metric_name, feature_name=key, value=value + ) + ) return metrics -def _get_metric_name(ground_truths: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - iou: float): - +def _get_metric_name( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + iou: float, +): if _is_classification(ground_truths, predictions): return "classification" return f"{int(iou*100)}pct_iou" -def _is_classification(ground_truths: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - predictions: 
List[Union[ObjectAnnotation, - ClassificationAnnotation]]): +def _is_classification( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], +): # Check if either the prediction or label contains a classification annotation - return (len(predictions) and - isinstance(predictions[0], ClassificationAnnotation) or - len(ground_truths) and - isinstance(ground_truths[0], ClassificationAnnotation)) + return ( + len(predictions) + and isinstance(predictions[0], ClassificationAnnotation) + or len(ground_truths) + and isinstance(ground_truths[0], ClassificationAnnotation) + ) diff --git a/libs/labelbox/src/labelbox/data/metrics/group.py b/libs/labelbox/src/labelbox/data/metrics/group.py index 5579ac9ce..88f4eae8b 100644 --- a/libs/labelbox/src/labelbox/data/metrics/group.py +++ b/libs/labelbox/src/labelbox/data/metrics/group.py @@ -1,11 +1,18 @@ """ Tools for grouping features and labels so that we can compute metrics on the individual groups """ + from collections import defaultdict from typing import Dict, List, Tuple, Union from labelbox.data.annotation_types.annotation import ClassificationAnnotation -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnswer, Radio, Text +from labelbox.data.annotation_types.classification.classification import ( + Checklist, + ClassificationAnswer, + Radio, + Text, +) + try: from typing import Literal except ImportError: @@ -17,7 +24,7 @@ def get_identifying_key( features_a: List[FeatureSchema], features_b: List[FeatureSchema] -) -> Union[Literal['name'], Literal['feature_schema_id']]: +) -> Union[Literal["name"], Literal["feature_schema_id"]]: """ Checks to make sure that features in both sets contain the same type of identifying keys. This can either be the feature name or feature schema id. @@ -30,22 +37,24 @@ def get_identifying_key( """ all_schema_ids_defined_pred, all_names_defined_pred = all_have_key( - features_a) - if (not all_schema_ids_defined_pred and not all_names_defined_pred): + features_a + ) + if not all_schema_ids_defined_pred and not all_names_defined_pred: raise ValueError("All data must have feature_schema_ids or names set") all_schema_ids_defined_gt, all_names_defined_gt = all_have_key(features_b) # Prefer name becuse the user will be able to know what it means # Schema id incase that doesn't exist. - if (all_names_defined_pred and all_names_defined_gt): - return 'name' + if all_names_defined_pred and all_names_defined_gt: + return "name" elif all_schema_ids_defined_pred and all_schema_ids_defined_gt: - return 'feature_schema_id' + return "feature_schema_id" else: raise ValueError( "Ground truth and prediction annotations must have set all name or feature ids. " - "Otherwise there is no key to match on. Please update.") + "Otherwise there is no key to match on. Please update." + ) def all_have_key(features: List[FeatureSchema]) -> Tuple[bool, bool]: @@ -79,10 +88,9 @@ def all_have_key(features: List[FeatureSchema]) -> Tuple[bool, bool]: return all_schemas, all_names -def get_label_pairs(labels_a: list, - labels_b: list, - match_on="uid", - filter_mismatch=False) -> Dict[str, Tuple[Label, Label]]: +def get_label_pairs( + labels_a: list, labels_b: list, match_on="uid", filter_mismatch=False +) -> Dict[str, Tuple[Label, Label]]: """ This is a function to pairing a list of prediction labels and a list of ground truth labels easier. There are a few potentiall problems with this function. 
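# Illustrative sketch, not part of this patch: get_label_pairs() above pairs prediction
# labels with ground-truth labels that share the chosen key ("uid" or "external_id").
# The dict comprehension below mirrors that idea with SimpleNamespace placeholders
# standing in for real labelbox Label objects.
from types import SimpleNamespace

labels_a = [SimpleNamespace(uid="data-row-1", name="prediction")]
labels_b = [SimpleNamespace(uid="data-row-1", name="ground truth")]
match_on = "uid"

lookup_a = {getattr(label, match_on): label for label in labels_a}
lookup_b = {getattr(label, match_on): label for label in labels_b}
pairs = {key: (lookup_a[key], lookup_b[key]) for key in lookup_a if key in lookup_b}
print(pairs["data-row-1"][0].name, pairs["data-row-1"][1].name)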
@@ -101,7 +109,7 @@ def get_label_pairs(labels_a: list, """ - if match_on not in ['uid', 'external_id']: + if match_on not in ["uid", "external_id"]: raise ValueError("Can only match on `uid` or `exteranl_id`.") label_lookup_a = { @@ -147,9 +155,10 @@ def get_feature_pairs( """ identifying_key = get_identifying_key(features_a, features_b) - lookup_a, lookup_b = _create_feature_lookup( - features_a, - identifying_key), _create_feature_lookup(features_b, identifying_key) + lookup_a, lookup_b = ( + _create_feature_lookup(features_a, identifying_key), + _create_feature_lookup(features_b, identifying_key), + ) keys = set(lookup_a.keys()).union(set(lookup_b.keys())) result = defaultdict(list) @@ -158,8 +167,9 @@ def get_feature_pairs( return result -def _create_feature_lookup(features: List[FeatureSchema], - key: str) -> Dict[str, List[FeatureSchema]]: +def _create_feature_lookup( + features: List[FeatureSchema], key: str +) -> Dict[str, List[FeatureSchema]]: """ Groups annotation by name (if available otherwise feature schema id). @@ -172,29 +182,33 @@ def _create_feature_lookup(features: List[FeatureSchema], grouped_features = defaultdict(list) for feature in features: if isinstance(feature, ClassificationAnnotation): - #checklists + # checklists if isinstance(feature.value, Checklist): for answer in feature.value.answer: new_answer = Radio(answer=answer) new_annotation = ClassificationAnnotation( value=new_answer, name=answer.name, - feature_schema_id=answer.feature_schema_id) + feature_schema_id=answer.feature_schema_id, + ) - grouped_features[getattr(answer, - key)].append(new_annotation) + grouped_features[getattr(answer, key)].append( + new_annotation + ) elif isinstance(feature.value, Text): grouped_features[getattr(feature, key)].append(feature) else: - grouped_features[getattr(feature.value.answer, - key)].append(feature) + grouped_features[getattr(feature.value.answer, key)].append( + feature + ) else: grouped_features[getattr(feature, key)].append(feature) return grouped_features -def has_no_matching_annotations(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation]): +def has_no_matching_annotations( + ground_truths: List[ObjectAnnotation], predictions: List[ObjectAnnotation] +): if len(ground_truths) and not len(predictions): # No existing predictions but existing ground truths means no matches. 
return True @@ -204,6 +218,7 @@ def has_no_matching_annotations(ground_truths: List[ObjectAnnotation], return False -def has_no_annotations(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation]): +def has_no_annotations( + ground_truths: List[ObjectAnnotation], predictions: List[ObjectAnnotation] +): return not len(ground_truths) and not len(predictions) diff --git a/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py b/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py index e25035c1b..2a376d3fe 100644 --- a/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py +++ b/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py @@ -4,15 +4,32 @@ import numpy as np from shapely.geometry import Polygon -from ..group import get_feature_pairs, get_identifying_key, has_no_annotations, has_no_matching_annotations -from ...annotation_types import (ObjectAnnotation, ClassificationAnnotation, - Mask, Geometry, Point, Line, Checklist, Text, - TextEntity, Radio, ScalarMetricValue) - - -def miou(ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], - include_subclasses: bool) -> Optional[ScalarMetricValue]: +from ..group import ( + get_feature_pairs, + get_identifying_key, + has_no_annotations, + has_no_matching_annotations, +) +from ...annotation_types import ( + ObjectAnnotation, + ClassificationAnnotation, + Mask, + Geometry, + Point, + Line, + Checklist, + Text, + TextEntity, + Radio, + ScalarMetricValue, +) + + +def miou( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + include_subclasses: bool, +) -> Optional[ScalarMetricValue]: """ Computes miou for an arbitrary set of ground truth and predicted annotations. It first computes the iou for each metric and then takes the average (weighting each class equally) @@ -35,11 +52,11 @@ def miou(ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], return None if not len(ious) else np.mean(ious) -def feature_miou(ground_truths: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - include_subclasses: bool) -> Optional[ScalarMetricValue]: +def feature_miou( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + include_subclasses: bool, +) -> Optional[ScalarMetricValue]: """ Computes iou score for all features of the same class. @@ -52,7 +69,7 @@ def feature_miou(ground_truths: List[Union[ObjectAnnotation, float representing the iou score for the feature type if score can be computed otherwise None. """ if has_no_matching_annotations(ground_truths, predictions): - return 0. + return 0.0 elif has_no_annotations(ground_truths, predictions): return None elif isinstance(predictions[0].value, Mask): @@ -65,13 +82,16 @@ def feature_miou(ground_truths: List[Union[ObjectAnnotation, return ner_miou(ground_truths, predictions, include_subclasses) else: raise ValueError( - f"Unexpected annotation found. Found {type(predictions[0].value)}") + f"Unexpected annotation found. Found {type(predictions[0].value)}" + ) -def vector_miou(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation], - include_subclasses: bool, - buffer=70.) 
-> Optional[ScalarMetricValue]: +def vector_miou( + ground_truths: List[ObjectAnnotation], + predictions: List[ObjectAnnotation], + include_subclasses: bool, + buffer=70.0, +) -> Optional[ScalarMetricValue]: """ Computes iou score for all features with the same feature schema id. Calculation includes subclassifications. @@ -84,44 +104,57 @@ def vector_miou(ground_truths: List[ObjectAnnotation], If there are no matches then this returns none """ if has_no_matching_annotations(ground_truths, predictions): - return 0. + return 0.0 elif has_no_annotations(ground_truths, predictions): return None pairs = _get_vector_pairs(ground_truths, predictions, buffer=buffer) return object_pair_miou(pairs, include_subclasses) -def object_pair_miou(pairs: List[Tuple[ObjectAnnotation, ObjectAnnotation, - ScalarMetricValue]], - include_subclasses) -> ScalarMetricValue: +def object_pair_miou( + pairs: List[Tuple[ObjectAnnotation, ObjectAnnotation, ScalarMetricValue]], + include_subclasses, +) -> ScalarMetricValue: pairs.sort(key=lambda triplet: triplet[2], reverse=True) solution_agreements = [] solution_features = set() all_features = set() for prediction, ground_truth, agreement in pairs: all_features.update({id(prediction), id(ground_truth)}) - if id(prediction) not in solution_features and id( - ground_truth) not in solution_features: + if ( + id(prediction) not in solution_features + and id(ground_truth) not in solution_features + ): solution_features.update({id(prediction), id(ground_truth)}) if include_subclasses: - classification_iou = miou(prediction.classifications, - ground_truth.classifications, - include_subclasses=False) - classification_iou = classification_iou if classification_iou is not None else agreement + classification_iou = miou( + prediction.classifications, + ground_truth.classifications, + include_subclasses=False, + ) + classification_iou = ( + classification_iou + if classification_iou is not None + else agreement + ) solution_agreements.append( - (agreement + classification_iou) / 2.) + (agreement + classification_iou) / 2.0 + ) else: solution_agreements.append(agreement) # Add zeros for unmatched Features - solution_agreements.extend([0.0] * - (len(all_features) - len(solution_features))) + solution_agreements.extend( + [0.0] * (len(all_features) - len(solution_features)) + ) return np.mean(solution_agreements) -def mask_miou(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation], - include_subclasses: bool) -> Optional[ScalarMetricValue]: +def mask_miou( + ground_truths: List[ObjectAnnotation], + predictions: List[ObjectAnnotation], + include_subclasses: bool, +) -> Optional[ScalarMetricValue]: """ Computes iou score for all features with the same feature schema id. Calculation includes subclassifications. @@ -133,7 +166,7 @@ def mask_miou(ground_truths: List[ObjectAnnotation], float representing the iou score for the masks """ if has_no_matching_annotations(ground_truths, predictions): - return 0. 
+ return 0.0 elif has_no_annotations(ground_truths, predictions): return None @@ -141,22 +174,26 @@ def mask_miou(ground_truths: List[ObjectAnnotation], pairs = _get_mask_pairs(ground_truths, predictions) return object_pair_miou(pairs, include_subclasses=include_subclasses) - prediction_np = np.max([pred.value.draw(color=1) for pred in predictions], - axis=0) + prediction_np = np.max( + [pred.value.draw(color=1) for pred in predictions], axis=0 + ) ground_truth_np = np.max( [ground_truth.value.draw(color=1) for ground_truth in ground_truths], - axis=0) + axis=0, + ) if prediction_np.shape != ground_truth_np.shape: raise ValueError( "Prediction and mask must have the same shape." - f" Found {prediction_np.shape}/{ground_truth_np.shape}.") + f" Found {prediction_np.shape}/{ground_truth_np.shape}." + ) return _mask_iou(ground_truth_np, prediction_np) def classification_miou( - ground_truths: List[ClassificationAnnotation], - predictions: List[ClassificationAnnotation]) -> ScalarMetricValue: + ground_truths: List[ClassificationAnnotation], + predictions: List[ClassificationAnnotation], +) -> ScalarMetricValue: """ Computes iou score for all features with the same feature schema id. @@ -168,14 +205,15 @@ def classification_miou( """ if len(predictions) != len(ground_truths) != 1: - return 0. + return 0.0 prediction, ground_truth = predictions[0], ground_truths[0] if type(prediction) != type(ground_truth): raise TypeError( "Classification features must be the same type to compute agreement. " - f"Found `{type(prediction)}` and `{type(ground_truth)}`") + f"Found `{type(prediction)}` and `{type(ground_truth)}`" + ) if isinstance(prediction.value, Text): return text_iou(ground_truth.value, prediction.value) @@ -193,7 +231,8 @@ def radio_iou(ground_truth: Radio, prediction: Radio) -> ScalarMetricValue: """ key = get_identifying_key([prediction.answer], [ground_truth.answer]) return float( - getattr(prediction.answer, key) == getattr(ground_truth.answer, key)) + getattr(prediction.answer, key) == getattr(ground_truth.answer, key) + ) def text_iou(ground_truth: Text, prediction: Text) -> ScalarMetricValue: @@ -203,8 +242,9 @@ def text_iou(ground_truth: Text, prediction: Text) -> ScalarMetricValue: return float(prediction.answer == ground_truth.answer) -def checklist_iou(ground_truth: Checklist, - prediction: Checklist) -> ScalarMetricValue: +def checklist_iou( + ground_truth: Checklist, prediction: Checklist +) -> ScalarMetricValue: """ Calculates agreement between ground truth and predicted checklist items """ @@ -212,13 +252,15 @@ def checklist_iou(ground_truth: Checklist, schema_ids_pred = {getattr(answer, key) for answer in prediction.answer} schema_ids_label = {getattr(answer, key) for answer in ground_truth.answer} return float( - len(schema_ids_label & schema_ids_pred) / - len(schema_ids_label | schema_ids_pred)) + len(schema_ids_label & schema_ids_pred) + / len(schema_ids_label | schema_ids_pred) + ) def _get_vector_pairs( - ground_truths: List[ObjectAnnotation], predictions: List[ObjectAnnotation], - buffer: float + ground_truths: List[ObjectAnnotation], + predictions: List[ObjectAnnotation], + buffer: float, ) -> List[Tuple[ObjectAnnotation, ObjectAnnotation, ScalarMetricValue]]: """ # Get iou score for all pairs of ground truths and predictions @@ -226,14 +268,17 @@ def _get_vector_pairs( pairs = [] for ground_truth, prediction in product(ground_truths, predictions): if isinstance(prediction.value, Geometry) and isinstance( - ground_truth.value, Geometry): + ground_truth.value, Geometry + 
): if isinstance(prediction.value, (Line, Point)): - - score = _polygon_iou(prediction.value.shapely.buffer(buffer), - ground_truth.value.shapely.buffer(buffer)) + score = _polygon_iou( + prediction.value.shapely.buffer(buffer), + ground_truth.value.shapely.buffer(buffer), + ) else: - score = _polygon_iou(prediction.value.shapely, - ground_truth.value.shapely) + score = _polygon_iou( + prediction.value.shapely, ground_truth.value.shapely + ) pairs.append((ground_truth, prediction, score)) return pairs @@ -247,9 +292,11 @@ def _get_mask_pairs( pairs = [] for ground_truth, prediction in product(ground_truths, predictions): if isinstance(prediction.value, Mask) and isinstance( - ground_truth.value, Mask): - score = _mask_iou(prediction.value.draw(color=1), - ground_truth.value.draw(color=1)) + ground_truth.value, Mask + ): + score = _mask_iou( + prediction.value.draw(color=1), ground_truth.value.draw(color=1) + ) pairs.append((ground_truth, prediction, score)) return pairs @@ -259,7 +306,7 @@ def _polygon_iou(poly1: Polygon, poly2: Polygon) -> ScalarMetricValue: poly1, poly2 = _ensure_valid_poly(poly1), _ensure_valid_poly(poly2) if poly1.intersects(poly2): return poly1.intersection(poly2).area / poly1.union(poly2).area - return 0. + return 0.0 def _ensure_valid_poly(poly): @@ -286,22 +333,28 @@ def _get_ner_pairs( def _ner_iou(ner1: TextEntity, ner2: TextEntity): """Computes iou between two text entity annotations""" - intersection_start, intersection_end = max(ner1.start, ner2.start), min( - ner1.end, ner2.end) - union_start, union_end = min(ner1.start, - ner2.start), max(ner1.end, ner2.end) - #edge case of only one character in text + intersection_start, intersection_end = ( + max(ner1.start, ner2.start), + min(ner1.end, ner2.end), + ) + union_start, union_end = ( + min(ner1.start, ner2.start), + max(ner1.end, ner2.end), + ) + # edge case of only one character in text if union_start == union_end: return 1 - #if there is no intersection + # if there is no intersection if intersection_start > intersection_end: return 0 return (intersection_end - intersection_start) / (union_end - union_start) -def ner_miou(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation], - include_subclasses: bool) -> Optional[ScalarMetricValue]: +def ner_miou( + ground_truths: List[ObjectAnnotation], + predictions: List[ObjectAnnotation], + include_subclasses: bool, +) -> Optional[ScalarMetricValue]: """ Computes iou score for all features with the same feature schema id. Calculation includes subclassifications. @@ -314,8 +367,8 @@ def ner_miou(ground_truths: List[ObjectAnnotation], If there are no matches then this returns none """ if has_no_matching_annotations(ground_truths, predictions): - return 0. 
+ return 0.0 elif has_no_annotations(ground_truths, predictions): return None pairs = _get_ner_pairs(ground_truths, predictions) - return object_pair_miou(pairs, include_subclasses) \ No newline at end of file + return object_pair_miou(pairs, include_subclasses) diff --git a/libs/labelbox/src/labelbox/data/metrics/iou/iou.py b/libs/labelbox/src/labelbox/data/metrics/iou/iou.py index 357dc5dc9..9b0ce2695 100644 --- a/libs/labelbox/src/labelbox/data/metrics/iou/iou.py +++ b/libs/labelbox/src/labelbox/data/metrics/iou/iou.py @@ -1,19 +1,22 @@ # type: ignore from labelbox.data.annotation_types.metrics.scalar import ScalarMetric from typing import List, Optional, Union -from ...annotation_types import (Label, ObjectAnnotation, - ClassificationAnnotation) +from ...annotation_types import ( + Label, + ObjectAnnotation, + ClassificationAnnotation, +) from ..group import get_feature_pairs from .calculation import feature_miou from .calculation import miou -def miou_metric(ground_truths: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - include_subclasses=False) -> List[ScalarMetric]: +def miou_metric( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + include_subclasses=False, +) -> List[ScalarMetric]: """ Computes miou between two sets of annotations. These annotations should relate to the same data (image/video). @@ -34,11 +37,11 @@ def miou_metric(ground_truths: List[Union[ObjectAnnotation, return [ScalarMetric(metric_name="custom_iou", value=iou)] -def feature_miou_metric(ground_truths: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - include_subclasses=True) -> List[ScalarMetric]: +def feature_miou_metric( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + include_subclasses=True, +) -> List[ScalarMetric]: """ Computes the miou for each type of class in the list of annotations. These annotations should relate to the same data (image/video). @@ -56,21 +59,24 @@ def feature_miou_metric(ground_truths: List[Union[ObjectAnnotation, annotation_pairs = get_feature_pairs(predictions, ground_truths) metrics = [] for key in annotation_pairs: - - value = feature_miou(annotation_pairs[key][0], annotation_pairs[key][1], - include_subclasses) + value = feature_miou( + annotation_pairs[key][0], + annotation_pairs[key][1], + include_subclasses, + ) if value is None: continue metrics.append( - ScalarMetric(metric_name="custom_iou", - feature_name=key, - value=value)) + ScalarMetric( + metric_name="custom_iou", feature_name=key, value=value + ) + ) return metrics -def data_row_miou(ground_truth: Label, - prediction: Label, - include_subclasses=False) -> Optional[float]: +def data_row_miou( + ground_truth: Label, prediction: Label, include_subclasses=False +) -> Optional[float]: """ This function is no longer supported. Use miou() for raw values or miou_metric() for the metric @@ -84,5 +90,6 @@ def data_row_miou(ground_truth: Label, float indicating the iou score for this data row. 
Returns None if there are no annotations in ground_truth or prediction Labels """ - return miou(ground_truth.annotations, prediction.annotations, - include_subclasses) + return miou( + ground_truth.annotations, prediction.annotations, include_subclasses + ) diff --git a/libs/labelbox/src/labelbox/data/mixins.py b/libs/labelbox/src/labelbox/data/mixins.py index d8bc78de0..4440c8a72 100644 --- a/libs/labelbox/src/labelbox/data/mixins.py +++ b/libs/labelbox/src/labelbox/data/mixins.py @@ -2,7 +2,10 @@ from pydantic import BaseModel, field_validator, model_serializer -from labelbox.exceptions import ConfidenceNotSupportedException, CustomMetricsNotSupportedException +from labelbox.exceptions import ( + ConfidenceNotSupportedException, + CustomMetricsNotSupportedException, +) from warnings import warn @@ -20,11 +23,11 @@ def confidence_valid_float(cls, value): class ConfidenceNotSupportedMixin: - def __new__(cls, *args, **kwargs): if "confidence" in kwargs: raise ConfidenceNotSupportedException( - "Confidence is not supported for this annotation type yet") + "Confidence is not supported for this annotation type yet" + ) return super().__new__(cls) @@ -50,9 +53,9 @@ class CustomMetricsMixin(BaseModel): class CustomMetricsNotSupportedMixin: - def __new__(cls, *args, **kwargs): if "custom_metrics" in kwargs: raise CustomMetricsNotSupportedException( - "Custom metrics is not supported for this annotation type yet") + "Custom metrics is not supported for this annotation type yet" + ) return super().__new__(cls) diff --git a/libs/labelbox/src/labelbox/data/ontology.py b/libs/labelbox/src/labelbox/data/ontology.py index f19208873..4d2e66e95 100644 --- a/libs/labelbox/src/labelbox/data/ontology.py +++ b/libs/labelbox/src/labelbox/data/ontology.py @@ -1,13 +1,23 @@ from typing import Dict, List, Tuple, Union from labelbox.schema import ontology -from .annotation_types import (Text, Checklist, Radio, - ClassificationAnnotation, ObjectAnnotation, Mask, - Point, Line, Polygon, Rectangle, TextEntity) +from .annotation_types import ( + Text, + Checklist, + Radio, + ClassificationAnnotation, + ObjectAnnotation, + Mask, + Point, + Line, + Polygon, + Rectangle, + TextEntity, +) def get_feature_schema_lookup( - ontology_builder: ontology.OntologyBuilder + ontology_builder: ontology.OntologyBuilder, ) -> Tuple[Dict[str, str], Dict[str, str]]: tool_lookup = {} classification_lookup = {} @@ -19,11 +29,13 @@ def flatten_classification(classifications): f"feature_schema_id cannot be None for classification `{classification.name}`." ) if isinstance(classification, ontology.Classification): - classification_lookup[ - classification.name] = classification.feature_schema_id + classification_lookup[classification.name] = ( + classification.feature_schema_id + ) elif isinstance(classification, ontology.Option): - classification_lookup[ - classification.value] = classification.feature_schema_id + classification_lookup[classification.value] = ( + classification.feature_schema_id + ) else: raise TypeError( f"Unexpected type found in ontology. `{type(classification)}`" @@ -33,15 +45,18 @@ def flatten_classification(classifications): for tool in ontology_builder.tools: if tool.feature_schema_id is None: raise ValueError( - f"feature_schema_id cannot be None for tool `{tool.name}`.") + f"feature_schema_id cannot be None for tool `{tool.name}`." 
+ ) tool_lookup[tool.name] = tool.feature_schema_id flatten_classification(tool.classifications) flatten_classification(ontology_builder.classifications) return tool_lookup, classification_lookup -def _get_options(annotation: ClassificationAnnotation, - existing_options: List[ontology.Option]): +def _get_options( + annotation: ClassificationAnnotation, + existing_options: List[ontology.Option], +): if isinstance(annotation.value, Radio): answers = [annotation.value.answer] elif isinstance(annotation.value, Text): @@ -63,7 +78,7 @@ def _get_options(annotation: ClassificationAnnotation, def get_classifications( annotations: List[ClassificationAnnotation], - existing_classifications: List[ontology.Classification] + existing_classifications: List[ontology.Classification], ) -> List[ontology.Classification]: existing_classifications = { classification.name: classification @@ -74,37 +89,45 @@ def get_classifications( classification_feature = existing_classifications.get(annotation.name) if classification_feature: classification_feature.options = _get_options( - annotation, classification_feature.options) + annotation, classification_feature.options + ) elif annotation.name not in existing_classifications: existing_classifications[annotation.name] = ontology.Classification( class_type=classification_mapping(annotation), name=annotation.name, - options=_get_options(annotation, [])) + options=_get_options(annotation, []), + ) return list(existing_classifications.values()) def get_tools( - annotations: List[ObjectAnnotation], - existing_tools: List[ontology.Classification]) -> List[ontology.Tool]: + annotations: List[ObjectAnnotation], + existing_tools: List[ontology.Classification], +) -> List[ontology.Tool]: existing_tools = {tool.name: tool for tool in existing_tools} for annotation in annotations: if annotation.name in existing_tools: # We just want to update classifications existing_tools[ - annotation.name].classifications = get_classifications( - annotation.classifications, - existing_tools[annotation.name].classifications) + annotation.name + ].classifications = get_classifications( + annotation.classifications, + existing_tools[annotation.name].classifications, + ) else: existing_tools[annotation.name] = ontology.Tool( tool=tool_mapping(annotation), name=annotation.name, - classifications=get_classifications(annotation.classifications, - [])) + classifications=get_classifications( + annotation.classifications, [] + ), + ) return list(existing_tools.values()) def tool_mapping( - annotation) -> Union[Mask, Polygon, Point, Rectangle, Line, TextEntity]: + annotation, +) -> Union[Mask, Polygon, Point, Rectangle, Line, TextEntity]: tool_types = ontology.Tool.Type mapping = { Mask: tool_types.SEGMENTATION, @@ -122,8 +145,7 @@ def tool_mapping( return result -def classification_mapping( - annotation) -> Union[Text, Checklist, Radio]: +def classification_mapping(annotation) -> Union[Text, Checklist, Radio]: classification_types = ontology.Classification.Type mapping = { Text: classification_types.TEXT, diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/annotation.py b/libs/labelbox/src/labelbox/data/serialization/coco/annotation.py index a0292e537..e387cb7d9 100644 --- a/libs/labelbox/src/labelbox/data/serialization/coco/annotation.py +++ b/libs/labelbox/src/labelbox/data/serialization/coco/annotation.py @@ -8,7 +8,9 @@ from ...annotation_types.metrics.scalar import ScalarMetric from ...annotation_types.video import VideoMaskAnnotation from ...annotation_types.annotation import 
ObjectAnnotation -from ...annotation_types.classification.classification import ClassificationAnnotation +from ...annotation_types.classification.classification import ( + ClassificationAnnotation, +) import numpy as np @@ -19,8 +21,9 @@ def rle_decoding(rle_arr: List[int], w: int, h: int) -> np.ndarray: indices = [] for idx, cnt in zip(rle_arr[0::2], rle_arr[1::2]): - indices.extend(list(range(idx - 1, - idx + cnt - 1))) # RLE is 1-based index + indices.extend( + list(range(idx - 1, idx + cnt - 1)) + ) # RLE is 1-based index mask = np.zeros(h * w, dtype=np.uint8) mask[indices] = 1 return mask.reshape((w, h)).T @@ -35,16 +38,18 @@ def get_annotation_lookup(annotations): annotation_lookup = defaultdict(list) for annotation in annotations: # Provide a default value of None if the attribute doesn't exist - attribute_value = getattr(annotation, 'image_id', None) or getattr(annotation, 'name', None) + attribute_value = getattr(annotation, "image_id", None) or getattr( + annotation, "name", None + ) annotation_lookup[attribute_value].append(annotation) - return annotation_lookup + return annotation_lookup class SegmentInfo(BaseModel): id: int category_id: int area: Union[float, int] - bbox: Tuple[float, float, float, float] #[x,y,w,h], + bbox: Tuple[float, float, float, float] # [x,y,w,h], iscrowd: int = 0 @@ -62,7 +67,7 @@ class COCOObjectAnnotation(BaseModel): category_id: int segmentation: Union[RLE, List[List[float]]] # [[x1,y1,x2,y2,x3,y3...]] area: float - bbox: Tuple[float, float, float, float] #[x,y,w,h], + bbox: Tuple[float, float, float, float] # [x,y,w,h], iscrowd: int = 0 diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/categories.py b/libs/labelbox/src/labelbox/data/serialization/coco/categories.py index 07ecacb03..60ba30fce 100644 --- a/libs/labelbox/src/labelbox/data/serialization/coco/categories.py +++ b/libs/labelbox/src/labelbox/data/serialization/coco/categories.py @@ -13,4 +13,5 @@ class Categories(BaseModel): def hash_category_name(name: str) -> int: return int.from_bytes( - md5(name.encode('utf-8')).hexdigest().encode('utf-8'), 'little') + md5(name.encode("utf-8")).hexdigest().encode("utf-8"), "little" + ) diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/converter.py b/libs/labelbox/src/labelbox/data/serialization/coco/converter.py index 1f6e8b178..e270b7573 100644 --- a/libs/labelbox/src/labelbox/data/serialization/coco/converter.py +++ b/libs/labelbox/src/labelbox/data/serialization/coco/converter.py @@ -8,8 +8,9 @@ from ...serialization.coco.panoptic_dataset import CocoPanopticDataset -def create_path_if_not_exists(path: Union[Path, str], - ignore_existing_data=False): +def create_path_if_not_exists( + path: Union[Path, str], ignore_existing_data=False +): path = Path(path) if not path.exists(): path.mkdir(parents=True, exist_ok=True) @@ -37,10 +38,12 @@ class COCOConverter: """ @staticmethod - def serialize_instances(labels: LabelCollection, - image_root: Union[Path, str], - ignore_existing_data=False, - max_workers=8) -> Dict[str, Any]: + def serialize_instances( + labels: LabelCollection, + image_root: Union[Path, str], + ignore_existing_data=False, + max_workers=8, + ) -> Dict[str, Any]: """ Convert a Labelbox LabelCollection into an mscoco dataset. This function will only convert masks, polygons, and rectangles. 
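# Illustrative sketch, not part of this patch: rough usage of the instance serializer
# above. The import path and the empty LabelCollection are placeholders; in practice
# `labels` would hold the image Labels being exported, and image_root is created if it
# does not already exist (see create_path_if_not_exists above).
from labelbox.data.serialization import COCOConverter  # import path assumed

labels = []  # substitute a real LabelCollection of image labels
coco_dict = COCOConverter.serialize_instances(
    labels=labels,
    image_root="./coco_export/images",  # placeholder directory
    ignore_existing_data=False,
    max_workers=8,
)
print(coco_dict.keys())  # info, images, annotations, categories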
@@ -60,20 +63,23 @@ def serialize_instances(labels: LabelCollection, warnings.warn( "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) image_root = create_path_if_not_exists(image_root, ignore_existing_data) - return CocoInstanceDataset.from_common(labels=labels, - image_root=image_root, - max_workers=max_workers).model_dump() + return CocoInstanceDataset.from_common( + labels=labels, image_root=image_root, max_workers=max_workers + ).model_dump() @staticmethod - def serialize_panoptic(labels: LabelCollection, - image_root: Union[Path, str], - mask_root: Union[Path, str], - all_stuff: bool = False, - ignore_existing_data=False, - max_workers: int = 8) -> Dict[str, Any]: + def serialize_panoptic( + labels: LabelCollection, + image_root: Union[Path, str], + mask_root: Union[Path, str], + all_stuff: bool = False, + ignore_existing_data=False, + max_workers: int = 8, + ) -> Dict[str, Any]: """ Convert a Labelbox LabelCollection into an mscoco dataset. This function will only convert masks, polygons, and rectangles. @@ -96,20 +102,25 @@ def serialize_panoptic(labels: LabelCollection, warnings.warn( "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) image_root = create_path_if_not_exists(image_root, ignore_existing_data) mask_root = create_path_if_not_exists(mask_root, ignore_existing_data) - return CocoPanopticDataset.from_common(labels=labels, - image_root=image_root, - mask_root=mask_root, - all_stuff=all_stuff, - max_workers=max_workers).model_dump() + return CocoPanopticDataset.from_common( + labels=labels, + image_root=image_root, + mask_root=mask_root, + all_stuff=all_stuff, + max_workers=max_workers, + ).model_dump() @staticmethod - def deserialize_panoptic(json_data: Dict[str, Any], image_root: Union[Path, - str], - mask_root: Union[Path, str]) -> LabelGenerator: + def deserialize_panoptic( + json_data: Dict[str, Any], + image_root: Union[Path, str], + mask_root: Union[Path, str], + ) -> LabelGenerator: """ Convert coco panoptic data into the labelbox format (as a LabelGenerator). @@ -124,17 +135,19 @@ def deserialize_panoptic(json_data: Dict[str, Any], image_root: Union[Path, warnings.warn( "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) - image_root = validate_path(image_root, 'image_root') - mask_root = validate_path(mask_root, 'mask_root') + image_root = validate_path(image_root, "image_root") + mask_root = validate_path(mask_root, "mask_root") objs = CocoPanopticDataset(**json_data) gen = objs.to_common(image_root, mask_root) return LabelGenerator(data=gen) @staticmethod - def deserialize_instances(json_data: Dict[str, Any], - image_root: Path) -> LabelGenerator: + def deserialize_instances( + json_data: Dict[str, Any], image_root: Path + ) -> LabelGenerator: """ Convert coco object data into the labelbox format (as a LabelGenerator). 
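# Illustrative sketch, not part of this patch: the reverse direction, turning a
# COCO-style dict back into a LabelGenerator. The JSON file path is a placeholder and
# image_root should point at the directory holding the referenced images.
import json

from labelbox.data.serialization import COCOConverter  # import path assumed

with open("./coco_export/instances.json") as f:  # placeholder file
    coco_json = json.load(f)

label_generator = COCOConverter.deserialize_instances(
    json_data=coco_json,
    image_root="./coco_export/images",
)
for label in label_generator:
    print(label.data.file_path, len(label.annotations))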
@@ -148,9 +161,10 @@ def deserialize_instances(json_data: Dict[str, Any], warnings.warn( "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) - image_root = validate_path(image_root, 'image_root') + image_root = validate_path(image_root, "image_root") objs = CocoInstanceDataset(**json_data) gen = objs.to_common(image_root) return LabelGenerator(data=gen) diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/image.py b/libs/labelbox/src/labelbox/data/serialization/coco/image.py index 71029b936..cef173377 100644 --- a/libs/labelbox/src/labelbox/data/serialization/coco/image.py +++ b/libs/labelbox/src/labelbox/data/serialization/coco/image.py @@ -47,6 +47,6 @@ def id_to_rgb(id: int) -> Tuple[int, int, int]: def rgb_to_id(red: int, green: int, blue: int) -> int: id = blue * 256 * 256 - id += (green * 256) + id += green * 256 id += red return id diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/instance_dataset.py b/libs/labelbox/src/labelbox/data/serialization/coco/instance_dataset.py index 7cade81a1..5241e596f 100644 --- a/libs/labelbox/src/labelbox/data/serialization/coco/instance_dataset.py +++ b/libs/labelbox/src/labelbox/data/serialization/coco/instance_dataset.py @@ -7,17 +7,34 @@ import numpy as np from tqdm import tqdm -from ...annotation_types import ImageData, MaskData, Mask, ObjectAnnotation, Label, Polygon, Point, Rectangle +from ...annotation_types import ( + ImageData, + MaskData, + Mask, + ObjectAnnotation, + Label, + Polygon, + Point, + Rectangle, +) from ...annotation_types.collection import LabelCollection from .categories import Categories, hash_category_name -from .annotation import COCOObjectAnnotation, RLE, get_annotation_lookup, rle_decoding +from .annotation import ( + COCOObjectAnnotation, + RLE, + get_annotation_lookup, + rle_decoding, +) from .image import CocoImage, get_image, get_image_id from pydantic import BaseModel def mask_to_coco_object_annotation( - annotation: ObjectAnnotation, annot_idx: int, image_id: int, - category_id: int) -> Optional[COCOObjectAnnotation]: + annotation: ObjectAnnotation, + annot_idx: int, + image_id: int, + category_id: int, +) -> Optional[COCOObjectAnnotation]: # This is going to fill any holes into the multipolygon # If you need to support holes use the panoptic data format shapely = annotation.value.shapely.simplify(1).buffer(0) @@ -38,12 +55,16 @@ def mask_to_coco_object_annotation( ], area=area, bbox=[xmin, ymin, xmax - xmin, ymax - ymin], - iscrowd=0) + iscrowd=0, + ) -def vector_to_coco_object_annotation(annotation: ObjectAnnotation, - annot_idx: int, image_id: int, - category_id: int) -> COCOObjectAnnotation: +def vector_to_coco_object_annotation( + annotation: ObjectAnnotation, + annot_idx: int, + image_id: int, + category_id: int, +) -> COCOObjectAnnotation: shapely = annotation.value.shapely xmin, ymin, xmax, ymax = shapely.bounds segmentation = [] @@ -52,61 +73,83 @@ def vector_to_coco_object_annotation(annotation: ObjectAnnotation, segmentation.extend([point.x, point.y]) else: box = annotation.value - segmentation.extend([ - box.start.x, box.start.y, box.end.x, box.start.y, box.end.x, - box.end.y, box.start.x, box.end.y - ]) - - return COCOObjectAnnotation(id=annot_idx, - image_id=image_id, - category_id=category_id, - segmentation=[segmentation], - area=shapely.area, - bbox=[xmin, ymin, xmax - xmin, ymax - ymin], - iscrowd=0) - - -def rle_to_common(class_annotations: COCOObjectAnnotation, 
- class_name: str) -> ObjectAnnotation: - mask = rle_decoding(class_annotations.segmentation.counts, - *class_annotations.segmentation.size[::-1]) - return ObjectAnnotation(name=class_name, - value=Mask(mask=MaskData.from_2D_arr(mask), - color=[1, 1, 1])) - - -def segmentations_to_common(class_annotations: COCOObjectAnnotation, - class_name: str) -> List[ObjectAnnotation]: + segmentation.extend( + [ + box.start.x, + box.start.y, + box.end.x, + box.start.y, + box.end.x, + box.end.y, + box.start.x, + box.end.y, + ] + ) + + return COCOObjectAnnotation( + id=annot_idx, + image_id=image_id, + category_id=category_id, + segmentation=[segmentation], + area=shapely.area, + bbox=[xmin, ymin, xmax - xmin, ymax - ymin], + iscrowd=0, + ) + + +def rle_to_common( + class_annotations: COCOObjectAnnotation, class_name: str +) -> ObjectAnnotation: + mask = rle_decoding( + class_annotations.segmentation.counts, + *class_annotations.segmentation.size[::-1], + ) + return ObjectAnnotation( + name=class_name, + value=Mask(mask=MaskData.from_2D_arr(mask), color=[1, 1, 1]), + ) + + +def segmentations_to_common( + class_annotations: COCOObjectAnnotation, class_name: str +) -> List[ObjectAnnotation]: # Technically it is polygons. But the key in coco is called segmentations.. annotations = [] for points in class_annotations.segmentation: annotations.append( - ObjectAnnotation(name=class_name, - value=Polygon(points=[ - Point(x=points[i], y=points[i + 1]) - for i in range(0, len(points), 2) - ]))) + ObjectAnnotation( + name=class_name, + value=Polygon( + points=[ + Point(x=points[i], y=points[i + 1]) + for i in range(0, len(points), 2) + ] + ), + ) + ) return annotations def object_annotation_to_coco( - annotation: ObjectAnnotation, annot_idx: int, image_id: int, - category_id: int) -> Optional[COCOObjectAnnotation]: + annotation: ObjectAnnotation, + annot_idx: int, + image_id: int, + category_id: int, +) -> Optional[COCOObjectAnnotation]: if isinstance(annotation.value, Mask): - return mask_to_coco_object_annotation(annotation, annot_idx, image_id, - category_id) + return mask_to_coco_object_annotation( + annotation, annot_idx, image_id, category_id + ) elif isinstance(annotation.value, (Polygon, Rectangle)): - return vector_to_coco_object_annotation(annotation, annot_idx, image_id, - category_id) + return vector_to_coco_object_annotation( + annotation, annot_idx, image_id, category_id + ) else: return None def process_label( - label: Label, - idx: int, - image_root: str, - max_annotations_per_image=10000 + label: Label, idx: int, image_root: str, max_annotations_per_image=10000 ) -> Tuple[np.ndarray, List[COCOObjectAnnotation], Dict[str, str]]: annot_idx = idx * max_annotations_per_image image_id = get_image_id(label, idx) @@ -117,9 +160,11 @@ def process_label( for class_name in annotation_lookup: for annotation in annotation_lookup[class_name]: category_id = categories.get(annotation.name) or hash_category_name( - annotation.name) - coco_annotation = object_annotation_to_coco(annotation, annot_idx, - image_id, category_id) + annotation.name + ) + coco_annotation = object_annotation_to_coco( + annotation, annot_idx, image_id, category_id + ) if coco_annotation is not None: coco_annotations.append(coco_annotation) if annotation.name not in categories: @@ -136,10 +181,9 @@ class CocoInstanceDataset(BaseModel): categories: List[Categories] @classmethod - def from_common(cls, - labels: LabelCollection, - image_root: Path, - max_workers=8): + def from_common( + cls, labels: LabelCollection, image_root: Path, 
max_workers=8 + ): all_coco_annotations = [] categories = {} images = [] @@ -156,7 +200,6 @@ def from_common(cls, future.result() for future in tqdm(as_completed(futures)) ] else: - results = [ process_label(label, idx, image_root) for idx, label in enumerate(labels) @@ -172,18 +215,23 @@ def from_common(cls, for idx, category_id in enumerate(coco_categories.values()) } categories = [ - Categories(id=category_mapping[idx], - name=name, - supercategory='all', - isthing=1) for name, idx in coco_categories.items() + Categories( + id=category_mapping[idx], + name=name, + supercategory="all", + isthing=1, + ) + for name, idx in coco_categories.items() ] for annot in all_coco_annotations: annot.category_id = category_mapping[annot.category_id] - return CocoInstanceDataset(info={'image_root': image_root}, - images=images, - annotations=all_coco_annotations, - categories=categories) + return CocoInstanceDataset( + info={"image_root": image_root}, + images=images, + annotations=all_coco_annotations, + categories=categories, + ) def to_common(self, image_root): category_lookup = { @@ -204,11 +252,15 @@ def to_common(self, image_root): if isinstance(class_annotations.segmentation, RLE): annotations.append( rle_to_common( - class_annotations, category_lookup[ - class_annotations.category_id].name)) + class_annotations, + category_lookup[class_annotations.category_id].name, + ) + ) elif isinstance(class_annotations.segmentation, list): annotations.extend( segmentations_to_common( - class_annotations, category_lookup[ - class_annotations.category_id].name)) + class_annotations, + category_lookup[class_annotations.category_id].name, + ) + ) yield Label(data=data, annotations=annotations) diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/panoptic_dataset.py b/libs/labelbox/src/labelbox/data/serialization/coco/panoptic_dataset.py index 4d6b9e2ef..cbb410548 100644 --- a/libs/labelbox/src/labelbox/data/serialization/coco/panoptic_dataset.py +++ b/libs/labelbox/src/labelbox/data/serialization/coco/panoptic_dataset.py @@ -18,29 +18,36 @@ from pydantic import BaseModel -def vector_to_coco_segment_info(canvas: np.ndarray, - annotation: ObjectAnnotation, - annotation_idx: int, image: CocoImage, - category_id: int): - +def vector_to_coco_segment_info( + canvas: np.ndarray, + annotation: ObjectAnnotation, + annotation_idx: int, + image: CocoImage, + category_id: int, +): shapely = annotation.value.shapely if shapely.is_empty: return xmin, ymin, xmax, ymax = shapely.bounds - canvas = annotation.value.draw(height=image.height, - width=image.width, - canvas=canvas, - color=id_to_rgb(annotation_idx)) - - return SegmentInfo(id=annotation_idx, - category_id=category_id, - area=shapely.area, - bbox=[xmin, ymin, xmax - xmin, ymax - ymin]), canvas - - -def mask_to_coco_segment_info(canvas: np.ndarray, annotation, - annotation_idx: int, category_id): + canvas = annotation.value.draw( + height=image.height, + width=image.width, + canvas=canvas, + color=id_to_rgb(annotation_idx), + ) + + return SegmentInfo( + id=annotation_idx, + category_id=category_id, + area=shapely.area, + bbox=[xmin, ymin, xmax - xmin, ymax - ymin], + ), canvas + + +def mask_to_coco_segment_info( + canvas: np.ndarray, annotation, annotation_idx: int, category_id +): color = id_to_rgb(annotation_idx) mask = annotation.value.draw(color=color) shapely = annotation.value.shapely @@ -49,17 +56,17 @@ def mask_to_coco_segment_info(canvas: np.ndarray, annotation, xmin, ymin, xmax, ymax = shapely.bounds canvas = np.where(canvas == (0, 0, 0), mask, 
canvas) - return SegmentInfo(id=annotation_idx, - category_id=category_id, - area=shapely.area, - bbox=[xmin, ymin, xmax - xmin, ymax - ymin]), canvas + return SegmentInfo( + id=annotation_idx, + category_id=category_id, + area=shapely.area, + bbox=[xmin, ymin, xmax - xmin, ymax - ymin], + ), canvas -def process_label(label: Label, - idx: Union[int, str], - image_root, - mask_root, - all_stuff=False): +def process_label( + label: Label, idx: Union[int, str], image_root, mask_root, all_stuff=False +): """ Masks become stuff Polygon and rectangle become thing @@ -78,8 +85,11 @@ def process_label(label: Label, categories[annotation.name] = hash_category_name(annotation.name) if isinstance(annotation.value, Mask): coco_segment_info = mask_to_coco_segment_info( - canvas, annotation, class_idx + 1, - categories[annotation.name]) + canvas, + annotation, + class_idx + 1, + categories[annotation.name], + ) if coco_segment_info is None: # Filter out empty masks @@ -96,7 +106,8 @@ def process_label(label: Label, annotation_idx=(class_idx if all_stuff else annotation_idx) + 1, image=image, - category_id=categories[annotation.name]) + category_id=categories[annotation.name], + ) if coco_vector_info is None: # Filter out empty annotations @@ -106,13 +117,19 @@ def process_label(label: Label, segments.append(segment) is_thing[annotation.name] = 1 - int(all_stuff) - mask_file = str(image.file_name).replace('.jpg', '.png') + mask_file = str(image.file_name).replace(".jpg", ".png") mask_file = Path(mask_root, mask_file) Image.fromarray(canvas.astype(np.uint8)).save(mask_file) - return image, PanopticAnnotation( - image_id=image_id, - file_name=Path(mask_file.name), - segments_info=segments), categories, is_thing + return ( + image, + PanopticAnnotation( + image_id=image_id, + file_name=Path(mask_file.name), + segments_info=segments, + ), + categories, + is_thing, + ) class CocoPanopticDataset(BaseModel): @@ -122,12 +139,14 @@ class CocoPanopticDataset(BaseModel): categories: List[Categories] @classmethod - def from_common(cls, - labels: LabelCollection, - image_root, - mask_root, - all_stuff, - max_workers=8): + def from_common( + cls, + labels: LabelCollection, + image_root, + mask_root, + all_stuff, + max_workers=8, + ): all_coco_annotations = [] coco_categories = {} coco_things = {} @@ -136,8 +155,15 @@ def from_common(cls, if max_workers: with ProcessPoolExecutor(max_workers=max_workers) as exc: futures = [ - exc.submit(process_label, label, idx, image_root, mask_root, - all_stuff) for idx, label in enumerate(labels) + exc.submit( + process_label, + label, + idx, + image_root, + mask_root, + all_stuff, + ) + for idx, label in enumerate(labels) ] results = [ future.result() for future in tqdm(as_completed(futures)) @@ -159,10 +185,12 @@ def from_common(cls, for idx, category_id in enumerate(coco_categories.values()) } categories = [ - Categories(id=category_mapping[idx], - name=name, - supercategory='all', - isthing=coco_things.get(name, 1)) + Categories( + id=category_mapping[idx], + name=name, + supercategory="all", + isthing=coco_things.get(name, 1), + ) for name, idx in coco_categories.items() ] @@ -170,13 +198,12 @@ def from_common(cls, for segment in annot.segments_info: segment.category_id = category_mapping[segment.category_id] - return CocoPanopticDataset(info={ - 'image_root': image_root, - 'mask_root': mask_root - }, - images=images, - annotations=all_coco_annotations, - categories=categories) + return CocoPanopticDataset( + info={"image_root": image_root, "mask_root": mask_root}, + 
images=images, + annotations=all_coco_annotations, + categories=categories, + ) def to_common(self, image_root: Path, mask_root: Path): category_lookup = { @@ -194,20 +221,22 @@ def to_common(self, image_root: Path, mask_root: Path): raise ValueError( f"Cannot find file {im_path}. Make sure `image_root` is set properly" ) - if not str(annotation.file_name).endswith('.png'): + if not str(annotation.file_name).endswith(".png"): raise ValueError( f"COCO masks must be stored as png files and their extension must be `.png`. Found {annotation.file_name}" ) mask = MaskData( - file_path=str(Path(mask_root, annotation.file_name))) + file_path=str(Path(mask_root, annotation.file_name)) + ) for segmentation in annotation.segments_info: category = category_lookup[segmentation.category_id] annotations.append( - ObjectAnnotation(name=category.name, - value=Mask(mask=mask, - color=id_to_rgb( - segmentation.id)))) + ObjectAnnotation( + name=category.name, + value=Mask(mask=mask, color=id_to_rgb(segmentation.id)), + ) + ) data = ImageData(file_path=str(im_path)) yield Label(data=data, annotations=annotations) del annotation_lookup[image.id] diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/path.py b/libs/labelbox/src/labelbox/data/serialization/coco/path.py index 8f6786655..c3be84f31 100644 --- a/libs/labelbox/src/labelbox/data/serialization/coco/path.py +++ b/libs/labelbox/src/labelbox/data/serialization/coco/path.py @@ -1,8 +1,8 @@ from pathlib import Path from pydantic import BaseModel, model_serializer -class PathSerializerMixin(BaseModel): +class PathSerializerMixin(BaseModel): @model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py index 602fa7628..8770222b9 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py @@ -9,20 +9,20 @@ subclass_registry = {} + class _SubclassRegistryBase(BaseModel): - model_config = ConfigDict(extra="allow") - + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if cls.__name__ != "NDAnnotation": with threading.Lock(): - subclass_registry[cls.__name__] = cls + subclass_registry[cls.__name__] = cls + class DataRow(_CamelCaseMixin): id: Optional[str] = None global_key: Optional[str] = None - @model_validator(mode="after") def must_set_one(self): @@ -45,6 +45,8 @@ class NDAnnotation(NDJsonBase): @model_validator(mode="after") def must_set_one(self): - if (not hasattr(self, "schema_id") or self.schema_id is None) and (not hasattr(self, "name") or self.name is None): + if (not hasattr(self, "schema_id") or self.schema_id is None) and ( + not hasattr(self, "name") or self.name is None + ): raise ValueError("Schema id or name are not set. 
Set either one.") return self diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py index e655e9f36..f4bc7e528 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py @@ -1,15 +1,33 @@ from typing import Any, Dict, List, Union, Optional -from labelbox.data.mixins import ConfidenceMixin, CustomMetric, CustomMetricsMixin +from labelbox.data.mixins import ( + ConfidenceMixin, + CustomMetric, + CustomMetricsMixin, +) from labelbox.data.serialization.ndjson.base import DataRow, NDAnnotation from ...annotation_types.annotation import ClassificationAnnotation from ...annotation_types.video import VideoClassificationAnnotation -from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation, PromptText -from ...annotation_types.classification.classification import ClassificationAnswer, Text, Checklist, Radio +from ...annotation_types.llm_prompt_response.prompt import ( + PromptClassificationAnnotation, + PromptText, +) +from ...annotation_types.classification.classification import ( + ClassificationAnswer, + Text, + Checklist, + Radio, +) from ...annotation_types.types import Cuid from ...annotation_types.data import TextData, VideoData, ImageData -from pydantic import model_validator, Field, BaseModel, ConfigDict, model_serializer +from pydantic import ( + model_validator, + Field, + BaseModel, + ConfigDict, + model_serializer, +) from pydantic.alias_generators import to_camel from .base import _SubclassRegistryBase @@ -17,24 +35,26 @@ class NDAnswer(ConfidenceMixin, CustomMetricsMixin): name: Optional[str] = None schema_id: Optional[Cuid] = None - classifications: Optional[List['NDSubclassificationType']] = None - model_config = ConfigDict(populate_by_name = True, alias_generator = to_camel) + classifications: Optional[List["NDSubclassificationType"]] = None + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) @model_validator(mode="after") def must_set_one(self): - if (not hasattr(self, "schema_id") or self.schema_id is None) and (not hasattr(self, "name") or self.name is None): + if (not hasattr(self, "schema_id") or self.schema_id is None) and ( + not hasattr(self, "name") or self.name is None + ): raise ValueError("Schema id or name are not set. Set either one.") return self @model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) - if 'name' in res and res['name'] is None: - res.pop('name') - if 'schemaId' in res and res['schemaId'] is None: - res.pop('schemaId') + if "name" in res and res["name"] is None: + res.pop("name") + if "schemaId" in res and res["schemaId"] is None: + res.pop("schemaId") if self.classifications: - res['classifications'] = [ + res["classifications"] = [ c.model_dump(exclude_none=True) for c in self.classifications ] return res @@ -54,7 +74,7 @@ def serialize_model(self, handler): res = handler(self) # This means these are no video frames .. 
if self.frames is None: - res.pop('frames') + res.pop("frames") return res @@ -62,13 +82,16 @@ class NDTextSubclass(NDAnswer): answer: str def to_common(self) -> Text: - return Text(answer=self.answer, - confidence=self.confidence, - custom_metrics=self.custom_metrics) + return Text( + answer=self.answer, + confidence=self.confidence, + custom_metrics=self.custom_metrics, + ) @classmethod - def from_common(cls, text: Text, name: str, - feature_schema_id: Cuid) -> "NDTextSubclass": + def from_common( + cls, text: Text, name: str, feature_schema_id: Cuid + ) -> "NDTextSubclass": return cls( answer=text.answer, name=name, @@ -79,41 +102,56 @@ def from_common(cls, text: Text, name: str, class NDChecklistSubclass(NDAnswer): - answer: List[NDAnswer] = Field(..., validation_alias='answers') + answer: List[NDAnswer] = Field(..., validation_alias="answers") def to_common(self) -> Checklist: - - return Checklist(answer=[ - ClassificationAnswer(name=answer.name, - feature_schema_id=answer.schema_id, - confidence=answer.confidence, - classifications=[ - NDSubclassification.to_common(annot) - for annot in answer.classifications - ] if answer.classifications else None, - custom_metrics=answer.custom_metrics) - for answer in self.answer - ]) + return Checklist( + answer=[ + ClassificationAnswer( + name=answer.name, + feature_schema_id=answer.schema_id, + confidence=answer.confidence, + classifications=[ + NDSubclassification.to_common(annot) + for annot in answer.classifications + ] + if answer.classifications + else None, + custom_metrics=answer.custom_metrics, + ) + for answer in self.answer + ] + ) @classmethod - def from_common(cls, checklist: Checklist, name: str, - feature_schema_id: Cuid) -> "NDChecklistSubclass": - return cls(answer=[ - NDAnswer(name=answer.name, - schema_id=answer.feature_schema_id, - confidence=answer.confidence, - classifications=[NDSubclassification.from_common(annot) for annot in answer.classifications] if answer.classifications else None, - custom_metrics=answer.custom_metrics) - for answer in checklist.answer - ], - name=name, - schema_id=feature_schema_id) + def from_common( + cls, checklist: Checklist, name: str, feature_schema_id: Cuid + ) -> "NDChecklistSubclass": + return cls( + answer=[ + NDAnswer( + name=answer.name, + schema_id=answer.feature_schema_id, + confidence=answer.confidence, + classifications=[ + NDSubclassification.from_common(annot) + for annot in answer.classifications + ] + if answer.classifications + else None, + custom_metrics=answer.custom_metrics, + ) + for answer in checklist.answer + ], + name=name, + schema_id=feature_schema_id, + ) @model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) - if 'answers' in res: - res['answer'] = res['answers'] + if "answers" in res: + res["answer"] = res["answers"] del res["answers"] return res @@ -122,42 +160,57 @@ class NDRadioSubclass(NDAnswer): answer: NDAnswer def to_common(self) -> Radio: - return Radio(answer=ClassificationAnswer( - name=self.answer.name, - feature_schema_id=self.answer.schema_id, - confidence=self.answer.confidence, - classifications=[ - NDSubclassification.to_common(annot) - for annot in self.answer.classifications - ] if self.answer.classifications else None, - custom_metrics=self.answer.custom_metrics)) + return Radio( + answer=ClassificationAnswer( + name=self.answer.name, + feature_schema_id=self.answer.schema_id, + confidence=self.answer.confidence, + classifications=[ + NDSubclassification.to_common(annot) + for annot in self.answer.classifications + 
] + if self.answer.classifications + else None, + custom_metrics=self.answer.custom_metrics, + ) + ) @classmethod - def from_common(cls, radio: Radio, name: str, - feature_schema_id: Cuid) -> "NDRadioSubclass": - return cls(answer=NDAnswer(name=radio.answer.name, - schema_id=radio.answer.feature_schema_id, - confidence=radio.answer.confidence, - classifications=[ - NDSubclassification.from_common(annot) - for annot in radio.answer.classifications - ] if radio.answer.classifications else None, - custom_metrics=radio.answer.custom_metrics), - name=name, - schema_id=feature_schema_id) + def from_common( + cls, radio: Radio, name: str, feature_schema_id: Cuid + ) -> "NDRadioSubclass": + return cls( + answer=NDAnswer( + name=radio.answer.name, + schema_id=radio.answer.feature_schema_id, + confidence=radio.answer.confidence, + classifications=[ + NDSubclassification.from_common(annot) + for annot in radio.answer.classifications + ] + if radio.answer.classifications + else None, + custom_metrics=radio.answer.custom_metrics, + ), + name=name, + schema_id=feature_schema_id, + ) class NDPromptTextSubclass(NDAnswer): answer: str def to_common(self) -> PromptText: - return PromptText(answer=self.answer, - confidence=self.confidence, - custom_metrics=self.custom_metrics) + return PromptText( + answer=self.answer, + confidence=self.confidence, + custom_metrics=self.custom_metrics, + ) @classmethod - def from_common(cls, prompt_text: PromptText, name: str, - feature_schema_id: Cuid) -> "NDPromptTextSubclass": + def from_common( + cls, prompt_text: PromptText, name: str, feature_schema_id: Cuid + ) -> "NDPromptTextSubclass": return cls( answer=prompt_text.answer, name=name, @@ -171,17 +224,18 @@ def from_common(cls, prompt_text: PromptText, name: str, class NDText(NDAnnotation, NDTextSubclass, _SubclassRegistryBase): - @classmethod - def from_common(cls, - uuid: str, - text: Text, - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[TextData, ImageData], - message_id: str, - confidence: Optional[float] = None) -> "NDText": + def from_common( + cls, + uuid: str, + text: Text, + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[TextData, ImageData], + message_id: str, + confidence: Optional[float] = None, + ) -> "NDText": return cls( answer=text.answer, data_row=DataRow(id=data.uid, global_key=data.global_key), @@ -194,8 +248,9 @@ def from_common(cls, ) -class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported, _SubclassRegistryBase): - +class NDChecklist( + NDAnnotation, NDChecklistSubclass, VideoSupported, _SubclassRegistryBase +): @model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) @@ -205,40 +260,46 @@ def serialize_model(self, handler): @classmethod def from_common( - cls, - uuid: str, - checklist: Checklist, - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[VideoData, TextData, ImageData], - message_id: str, - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None + cls, + uuid: str, + checklist: Checklist, + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[VideoData, TextData, ImageData], + message_id: str, + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDChecklist": + return cls( + answer=[ + NDAnswer( + name=answer.name, + schema_id=answer.feature_schema_id, + confidence=answer.confidence, + classifications=[ + NDSubclassification.from_common(annot) + for 
annot in answer.classifications + ] + if answer.classifications + else None, + custom_metrics=answer.custom_metrics, + ) + for answer in checklist.answer + ], + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + frames=extra.get("frames"), + message_id=message_id, + confidence=confidence, + ) - return cls(answer=[ - NDAnswer(name=answer.name, - schema_id=answer.feature_schema_id, - confidence=answer.confidence, - classifications=[ - NDSubclassification.from_common(annot) - for annot in answer.classifications - ] if answer.classifications else None, - custom_metrics=answer.custom_metrics) - for answer in checklist.answer - ], - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - frames=extra.get('frames'), - message_id=message_id, - confidence=confidence) - - -class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported, _SubclassRegistryBase): +class NDRadio( + NDAnnotation, NDRadioSubclass, VideoSupported, _SubclassRegistryBase +): @classmethod def from_common( cls, @@ -251,32 +312,37 @@ def from_common( message_id: str, confidence: Optional[float] = None, ) -> "NDRadio": - return cls(answer=NDAnswer(name=radio.answer.name, - schema_id=radio.answer.feature_schema_id, - confidence=radio.answer.confidence, - classifications=[ - NDSubclassification.from_common(annot) - for annot in radio.answer.classifications - ] if radio.answer.classifications else None, - custom_metrics=radio.answer.custom_metrics), - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - frames=extra.get('frames'), - message_id=message_id, - confidence=confidence) - + return cls( + answer=NDAnswer( + name=radio.answer.name, + schema_id=radio.answer.feature_schema_id, + confidence=radio.answer.confidence, + classifications=[ + NDSubclassification.from_common(annot) + for annot in radio.answer.classifications + ] + if radio.answer.classifications + else None, + custom_metrics=radio.answer.custom_metrics, + ), + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + frames=extra.get("frames"), + message_id=message_id, + confidence=confidence, + ) + @model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) if "classifications" in res and res["classifications"] == []: del res["classifications"] return res - - + + class NDPromptText(NDAnnotation, NDPromptTextSubclass, _SubclassRegistryBase): - @classmethod def from_common( cls, @@ -285,7 +351,7 @@ def from_common( name, data: Dict, feature_schema_id: Cuid, - confidence: Optional[float] = None + confidence: Optional[float] = None, ) -> "NDPromptText": return cls( answer=text.answer, @@ -294,11 +360,11 @@ def from_common( schema_id=feature_schema_id, uuid=uuid, confidence=text.confidence, - custom_metrics=text.custom_metrics) + custom_metrics=text.custom_metrics, + ) class NDSubclassification: - @classmethod def from_common( cls, annotation: ClassificationAnnotation @@ -308,19 +374,23 @@ def from_common( raise TypeError( f"Unable to convert object to MAL format. 
`{type(annotation.value)}`" ) - return classify_obj.from_common(annotation.value, annotation.name, - annotation.feature_schema_id) + return classify_obj.from_common( + annotation.value, annotation.name, annotation.feature_schema_id + ) @staticmethod def to_common( - annotation: "NDClassificationType") -> ClassificationAnnotation: - return ClassificationAnnotation(value=annotation.to_common(), - name=annotation.name, - feature_schema_id=annotation.schema_id) + annotation: "NDClassificationType", + ) -> ClassificationAnnotation: + return ClassificationAnnotation( + value=annotation.to_common(), + name=annotation.name, + feature_schema_id=annotation.schema_id, + ) @staticmethod def lookup_subclassification( - annotation: ClassificationAnnotation + annotation: ClassificationAnnotation, ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: return { Text: NDTextSubclass, @@ -330,69 +400,76 @@ def lookup_subclassification( class NDClassification: - @staticmethod def to_common( - annotation: "NDClassificationType" + annotation: "NDClassificationType", ) -> Union[ClassificationAnnotation, VideoClassificationAnnotation]: common = ClassificationAnnotation( value=annotation.to_common(), name=annotation.name, feature_schema_id=annotation.schema_id, - extra={'uuid': annotation.uuid}, + extra={"uuid": annotation.uuid}, message_id=annotation.message_id, confidence=annotation.confidence, ) - if getattr(annotation, 'frames', None) is None: + if getattr(annotation, "frames", None) is None: return [common] results = [] for frame in annotation.frames: for idx in range(frame.start, frame.end + 1, 1): results.append( - VideoClassificationAnnotation(frame=idx, **common.model_dump(exclude_none=True))) + VideoClassificationAnnotation( + frame=idx, **common.model_dump(exclude_none=True) + ) + ) return results @classmethod def from_common( - cls, annotation: Union[ClassificationAnnotation, - VideoClassificationAnnotation], - data: Union[VideoData, TextData, ImageData] + cls, + annotation: Union[ + ClassificationAnnotation, VideoClassificationAnnotation + ], + data: Union[VideoData, TextData, ImageData], ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: classify_obj = cls.lookup_classification(annotation) if classify_obj is None: raise TypeError( f"Unable to convert object to MAL format. 
`{type(annotation.value)}`" ) - return classify_obj.from_common(str(annotation._uuid), annotation.value, - annotation.name, - annotation.feature_schema_id, - annotation.extra, data, - annotation.message_id, - annotation.confidence) + return classify_obj.from_common( + str(annotation._uuid), + annotation.value, + annotation.name, + annotation.feature_schema_id, + annotation.extra, + data, + annotation.message_id, + annotation.confidence, + ) @staticmethod def lookup_classification( - annotation: Union[ClassificationAnnotation, - VideoClassificationAnnotation] + annotation: Union[ + ClassificationAnnotation, VideoClassificationAnnotation + ], ) -> Union[NDText, NDChecklist, NDRadio]: - return { - Text: NDText, - Checklist: NDChecklist, - Radio: NDRadio - }.get(type(annotation.value)) + return {Text: NDText, Checklist: NDChecklist, Radio: NDRadio}.get( + type(annotation.value) + ) -class NDPromptClassification: +class NDPromptClassification: @staticmethod def to_common( - annotation: "NDPromptClassificationType" + annotation: "NDPromptClassificationType", ) -> Union[PromptClassificationAnnotation]: common = PromptClassificationAnnotation( value=annotation, name=annotation.name, feature_schema_id=annotation.schema_id, - extra={'uuid': annotation.uuid}, + extra={"uuid": annotation.uuid}, confidence=annotation.confidence, ) @@ -400,20 +477,25 @@ def to_common( @classmethod def from_common( - cls, annotation: Union[PromptClassificationAnnotation], - data: Union[VideoData, TextData, ImageData] + cls, + annotation: Union[PromptClassificationAnnotation], + data: Union[VideoData, TextData, ImageData], ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: - return NDPromptText.from_common(str(annotation._uuid), annotation.value, - annotation.name, - data, - annotation.feature_schema_id, - annotation.confidence) + return NDPromptText.from_common( + str(annotation._uuid), + annotation.value, + annotation.name, + data, + annotation.feature_schema_id, + annotation.confidence, + ) # Make sure to keep NDChecklistSubclass prior to NDRadioSubclass in the list, # otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used -NDSubclassificationType = Union[NDChecklistSubclass, NDRadioSubclass, - NDTextSubclass] +NDSubclassificationType = Union[ + NDChecklistSubclass, NDRadioSubclass, NDTextSubclass +] NDAnswer.model_rebuild() NDChecklistSubclass.model_rebuild() @@ -427,4 +509,4 @@ def from_common( # Make sure to keep NDChecklist prior to NDRadio in the list, # otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used NDClassificationType = Union[NDChecklist, NDRadio, NDText] -NDPromptClassificationType = Union[NDPromptText] \ No newline at end of file +NDPromptClassificationType = Union[NDPromptText] diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py index a38247271..01ab8454a 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py @@ -6,9 +6,11 @@ from labelbox.data.annotation_types.annotation import ObjectAnnotation from labelbox.data.annotation_types.classification.classification import ( - ClassificationAnnotation,) + ClassificationAnnotation, +) from labelbox.data.annotation_types.metrics.confusion_matrix import ( - ConfusionMatrixMetric,) + ConfusionMatrixMetric, +) from labelbox.data.annotation_types.metrics.scalar import ScalarMetric from 
labelbox.data.annotation_types.video import VideoMaskAnnotation @@ -24,7 +26,6 @@ class NDJsonConverter: - @staticmethod def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator: """ @@ -41,7 +42,8 @@ def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator: @staticmethod def serialize( - labels: LabelCollection) -> Generator[Dict[str, Any], None, None]: + labels: LabelCollection, + ) -> Generator[Dict[str, Any], None, None]: """ Converts a labelbox common object to the labelbox ndjson format (prediction import format) @@ -56,8 +58,9 @@ def serialize( """ used_uuids: Set[uuid.UUID] = set() - relationship_uuids: Dict[uuid.UUID, - Deque[uuid.UUID]] = defaultdict(deque) + relationship_uuids: Dict[uuid.UUID, Deque[uuid.UUID]] = defaultdict( + deque + ) # UUIDs are private properties used to enhance UX when defining relationships. # They are created for all annotations, but only utilized for relationships. @@ -66,15 +69,17 @@ def serialize( # For relationship annotations, during first pass, we update the UUIDs of the source and target annotations. # During the second pass, we update the UUIDs of the annotations referenced by the relationship annotations. for label in labels: - uuid_safe_annotations: List[Union[ - ClassificationAnnotation, - ObjectAnnotation, - VideoMaskAnnotation, - ScalarMetric, - ConfusionMatrixMetric, - RelationshipAnnotation, - MessageEvaluationTaskAnnotation, - ]] = [] + uuid_safe_annotations: List[ + Union[ + ClassificationAnnotation, + ObjectAnnotation, + VideoMaskAnnotation, + ScalarMetric, + ConfusionMatrixMetric, + RelationshipAnnotation, + MessageEvaluationTaskAnnotation, + ] + ] = [] # First pass to get all RelationshipAnnotaitons # and update the UUIDs of the source and target annotations for annotation in label.annotations: @@ -83,9 +88,11 @@ def serialize( new_source_uuid = uuid.uuid4() new_target_uuid = uuid.uuid4() relationship_uuids[annotation.value.source._uuid].append( - new_source_uuid) + new_source_uuid + ) relationship_uuids[annotation.value.target._uuid].append( - new_target_uuid) + new_target_uuid + ) annotation.value.source._uuid = new_source_uuid annotation.value.target._uuid = new_target_uuid if annotation._uuid in used_uuids: @@ -94,8 +101,9 @@ def serialize( uuid_safe_annotations.append(annotation) # Second pass to update UUIDs for annotations referenced by RelationshipAnnotations for annotation in label.annotations: - if (not isinstance(annotation, RelationshipAnnotation) and - hasattr(annotation, "_uuid")): + if not isinstance( + annotation, RelationshipAnnotation + ) and hasattr(annotation, "_uuid"): annotation = copy.deepcopy(annotation) next_uuids = relationship_uuids[annotation._uuid] if len(next_uuids) > 0: @@ -119,6 +127,6 @@ def serialize( for k, v in list(res.items()): if k in IGNORE_IF_NONE and v is None: del res[k] - if getattr(label, 'is_benchmark_reference'): - res['isBenchmarkReferenceLabel'] = True + if getattr(label, "is_benchmark_reference"): + res["isBenchmarkReferenceLabel"] = True yield res diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index b9e9f2456..18134a228 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -3,9 +3,15 @@ from typing import Dict, Generator, List, Tuple, Union from collections import defaultdict -from ...annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation +from 
...annotation_types.annotation import ( + ClassificationAnnotation, + ObjectAnnotation, +) from ...annotation_types.relationship import RelationshipAnnotation -from ...annotation_types.video import DICOMObjectAnnotation, VideoClassificationAnnotation +from ...annotation_types.video import ( + DICOMObjectAnnotation, + VideoClassificationAnnotation, +) from ...annotation_types.video import VideoObjectAnnotation, VideoMaskAnnotation from ...annotation_types.collection import LabelCollection, LabelGenerator from ...annotation_types.data import DicomData, ImageData, TextData, VideoData @@ -13,12 +19,29 @@ from ...annotation_types.label import Label from ...annotation_types.ner import TextEntity, ConversationEntity from ...annotation_types.metrics import ScalarMetric, ConfusionMatrixMetric -from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation +from ...annotation_types.llm_prompt_response.prompt import ( + PromptClassificationAnnotation, +) from ...annotation_types.mmc import MessageEvaluationTaskAnnotation from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric -from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass, NDPromptClassification, NDPromptClassificationType, NDPromptText -from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks +from .classification import ( + NDChecklistSubclass, + NDClassification, + NDClassificationType, + NDRadioSubclass, + NDPromptClassification, + NDPromptClassificationType, + NDPromptText, +) +from .objects import ( + NDObject, + NDObjectType, + NDSegments, + NDDicomSegments, + NDVideoMasks, + NDDicomMasks, +) from .mmc import NDMessageTask from .relationship import NDRelationship from .base import DataRow @@ -27,19 +50,29 @@ from pydantic_core import PydanticUndefined from contextlib import suppress -AnnotationType = Union[NDObjectType, NDClassificationType, NDPromptClassificationType, - NDConfusionMatrixMetric, NDScalarMetric, NDDicomSegments, - NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship, - NDPromptText, NDMessageTask] +AnnotationType = Union[ + NDObjectType, + NDClassificationType, + NDPromptClassificationType, + NDConfusionMatrixMetric, + NDScalarMetric, + NDDicomSegments, + NDSegments, + NDDicomMasks, + NDVideoMasks, + NDRelationship, + NDPromptText, + NDMessageTask, +] class NDLabel(BaseModel): annotations: List[_SubclassRegistryBase] - + def __init__(self, **kwargs): # NOTE: Deserialization of subclasses in pydantic is difficult, see here https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83 # Below implements the subclass registry as mentioned in the article. The python dicts we pass in can be missing certain fields - # we essentially have to infer the type against all sub classes that have the _SubclasssRegistryBase inheritance. + # we essentially have to infer the type against all sub classes that have the _SubclasssRegistryBase inheritance. # It works by checking if the keys of our annotations we are missing in matches any required subclass. # More keys are prioritized over less keys (closer match). This is used when importing json to our base models not a lot of customer workflows # depend on this method but this works for all our existing tests with the bonus of added validation. 
(no subclass found it throws an error) @@ -49,46 +82,64 @@ def __init__(self, **kwargs): item_annotation_keys = annotation.keys() key_subclass_combos = defaultdict(list) for subclass in subclass_registry.values(): - # Get all required keys from subclass annotation_keys = [] for k, field in subclass.model_fields.items(): if field.default == PydanticUndefined and k != "uuid": - if hasattr(field, "alias") and field.alias in item_annotation_keys: + if ( + hasattr(field, "alias") + and field.alias in item_annotation_keys + ): annotation_keys.append(field.alias) - elif hasattr(field, "validation_alias") and field.validation_alias in item_annotation_keys: + elif ( + hasattr(field, "validation_alias") + and field.validation_alias + in item_annotation_keys + ): annotation_keys.append(field.validation_alias) else: annotation_keys.append(k) - + key_subclass_combos[subclass].extend(annotation_keys) - + # Sort by subclass that has the most keys i.e. the one with the most keys that matches is most likely our subclass - key_subclass_combos = dict(sorted(key_subclass_combos.items(), key = lambda x : len(x[1]), reverse=True)) + key_subclass_combos = dict( + sorted( + key_subclass_combos.items(), + key=lambda x: len(x[1]), + reverse=True, + ) + ) for subclass, key_subclass_combo in key_subclass_combos.items(): # Choose the keys from our dict we supplied that matches the required keys of a subclass - check_required_keys = all(key in list(item_annotation_keys) for key in key_subclass_combo) + check_required_keys = all( + key in list(item_annotation_keys) + for key in key_subclass_combo + ) if check_required_keys: # Keep trying subclasses until we find one that has valid values (does not throw an validation error) with suppress(ValidationError): annotation = subclass(**annotation) break if isinstance(annotation, dict): - raise ValueError(f"Could not find subclass for fields: {item_annotation_keys}") - + raise ValueError( + f"Could not find subclass for fields: {item_annotation_keys}" + ) + kwargs["annotations"][index] = annotation super().__init__(**kwargs) - class _Relationship(BaseModel): """This object holds information about the relationship""" + ndjson: NDRelationship source: str target: str class _AnnotationGroup(BaseModel): """Stores all the annotations and relationships per datarow""" + data_row: DataRow = None ndjson_annotations: Dict[str, AnnotationType] = {} relationships: List["NDLabel._Relationship"] = [] @@ -97,7 +148,10 @@ def to_common(self) -> LabelGenerator: annotation_groups = defaultdict(NDLabel._AnnotationGroup) for ndjson_annotation in self.annotations: - key = ndjson_annotation.data_row.id or ndjson_annotation.data_row.global_key + key = ( + ndjson_annotation.data_row.id + or ndjson_annotation.data_row.global_key + ) group = annotation_groups[key] if isinstance(ndjson_annotation, NDRelationship): @@ -105,7 +159,9 @@ def to_common(self) -> LabelGenerator: NDLabel._Relationship( ndjson=ndjson_annotation, source=ndjson_annotation.relationship.source, - target=ndjson_annotation.relationship.target)) + target=ndjson_annotation.relationship.target, + ) + ) else: # if this is the first object in this group, we # take note of the DataRow this group belongs to @@ -117,17 +173,22 @@ def to_common(self) -> LabelGenerator: # we need to change the value type of # `_AnnotationGroupTuple.ndjson_objects` to accept a list of objects # and adapt the code to support duplicate UUIDs - assert ndjson_annotation.uuid not in group.ndjson_annotations, f"UUID '{ndjson_annotation.uuid}' is not unique" + assert ( + 
ndjson_annotation.uuid not in group.ndjson_annotations + ), f"UUID '{ndjson_annotation.uuid}' is not unique" - group.ndjson_annotations[ - ndjson_annotation.uuid] = ndjson_annotation + group.ndjson_annotations[ndjson_annotation.uuid] = ( + ndjson_annotation + ) return LabelGenerator( - data=self._generate_annotations(annotation_groups)) + data=self._generate_annotations(annotation_groups) + ) @classmethod - def from_common(cls, - data: LabelCollection) -> Generator["NDLabel", None, None]: + def from_common( + cls, data: LabelCollection + ) -> Generator["NDLabel", None, None]: for label in data: yield from cls._create_non_video_annotations(label) yield from cls._create_video_annotations(label) @@ -144,68 +205,96 @@ def _generate_annotations( for uuid, ndjson_annotation in group.ndjson_annotations.items(): if isinstance(ndjson_annotation, NDDicomSegments): annotations.extend( - NDDicomSegments.to_common(ndjson_annotation, - ndjson_annotation.name, - ndjson_annotation.schema_id)) + NDDicomSegments.to_common( + ndjson_annotation, + ndjson_annotation.name, + ndjson_annotation.schema_id, + ) + ) elif isinstance(ndjson_annotation, NDSegments): annotations.extend( - NDSegments.to_common(ndjson_annotation, - ndjson_annotation.name, - ndjson_annotation.schema_id)) + NDSegments.to_common( + ndjson_annotation, + ndjson_annotation.name, + ndjson_annotation.schema_id, + ) + ) elif isinstance(ndjson_annotation, NDDicomMasks): annotations.append( - NDDicomMasks.to_common(ndjson_annotation)) + NDDicomMasks.to_common(ndjson_annotation) + ) elif isinstance(ndjson_annotation, NDVideoMasks): annotations.append( - NDVideoMasks.to_common(ndjson_annotation)) + NDVideoMasks.to_common(ndjson_annotation) + ) elif isinstance(ndjson_annotation, NDObjectType.__args__): annotation = NDObject.to_common(ndjson_annotation) annotations.append(annotation) relationship_annotations[uuid] = annotation - elif isinstance(ndjson_annotation, - NDClassificationType.__args__): + elif isinstance( + ndjson_annotation, NDClassificationType.__args__ + ): annotations.extend( - NDClassification.to_common(ndjson_annotation)) - elif isinstance(ndjson_annotation, - (NDScalarMetric, NDConfusionMatrixMetric)): + NDClassification.to_common(ndjson_annotation) + ) + elif isinstance( + ndjson_annotation, (NDScalarMetric, NDConfusionMatrixMetric) + ): annotations.append( - NDMetricAnnotation.to_common(ndjson_annotation)) + NDMetricAnnotation.to_common(ndjson_annotation) + ) elif isinstance(ndjson_annotation, NDPromptClassificationType): - annotation = NDPromptClassification.to_common(ndjson_annotation) + annotation = NDPromptClassification.to_common( + ndjson_annotation + ) annotations.append(annotation) elif isinstance(ndjson_annotation, NDMessageTask): annotations.append(ndjson_annotation.to_common()) else: raise TypeError( - f"Unsupported annotation. {type(ndjson_annotation)}") + f"Unsupported annotation. 
{type(ndjson_annotation)}" + ) # after all the annotations have been discovered, we can now create # the relationship objects and use references to the objects # involved for relationship in group.relationships: try: - source, target = relationship_annotations[ - relationship.source], relationship_annotations[ - relationship.target] + source, target = ( + relationship_annotations[relationship.source], + relationship_annotations[relationship.target], + ) except KeyError: raise ValueError( f"Relationship object refers to nonexistent object with UUID '{relationship.source}' and/or '{relationship.target}'" ) annotations.append( - NDRelationship.to_common(relationship.ndjson, source, - target)) + NDRelationship.to_common( + relationship.ndjson, source, target + ) + ) - yield Label(annotations=annotations, - data=self._infer_media_type(group.data_row, - annotations)) + yield Label( + annotations=annotations, + data=self._infer_media_type(group.data_row, annotations), + ) def _infer_media_type( - self, data_row: DataRow, - annotations: List[Union[TextEntity, ConversationEntity, - VideoClassificationAnnotation, - DICOMObjectAnnotation, VideoObjectAnnotation, - ObjectAnnotation, ClassificationAnnotation, - ScalarMetric, ConfusionMatrixMetric]] + self, + data_row: DataRow, + annotations: List[ + Union[ + TextEntity, + ConversationEntity, + VideoClassificationAnnotation, + DICOMObjectAnnotation, + VideoObjectAnnotation, + ObjectAnnotation, + ClassificationAnnotation, + ScalarMetric, + ConfusionMatrixMetric, + ] + ], ) -> Union[TextData, VideoData, ImageData]: if len(annotations) == 0: raise ValueError("Missing annotations while inferring media type") @@ -214,7 +303,10 @@ def _infer_media_type( data = GenericDataRowData if (TextEntity in types) or (ConversationEntity in types): data = TextData - elif VideoClassificationAnnotation in types or VideoObjectAnnotation in types: + elif ( + VideoClassificationAnnotation in types + or VideoObjectAnnotation in types + ): data = VideoData elif DICOMObjectAnnotation in types: data = DicomData @@ -226,7 +318,8 @@ def _infer_media_type( @staticmethod def _get_consecutive_frames( - frames_indices: List[int]) -> List[Tuple[int, int]]: + frames_indices: List[int], + ) -> List[Tuple[int, int]]: consecutive = [] for k, g in groupby(enumerate(frames_indices), lambda x: x[0] - x[1]): group = list(map(itemgetter(1), g)) @@ -235,18 +328,23 @@ def _get_consecutive_frames( @classmethod def _get_segment_frame_ranges( - cls, annotation_group: List[Union[VideoClassificationAnnotation, - VideoObjectAnnotation]] + cls, + annotation_group: List[ + Union[VideoClassificationAnnotation, VideoObjectAnnotation] + ], ) -> List[Tuple[int, int]]: - sorted_frame_segment_indices = sorted([ - (annotation.frame, annotation.segment_index) - for annotation in annotation_group - if annotation.segment_index is not None - ]) + sorted_frame_segment_indices = sorted( + [ + (annotation.frame, annotation.segment_index) + for annotation in annotation_group + if annotation.segment_index is not None + ] + ) if len(sorted_frame_segment_indices) == 0: # Group segment by consecutive frames, since `segment_index` is not present return cls._get_consecutive_frames( - sorted([annotation.frame for annotation in annotation_group])) + sorted([annotation.frame for annotation in annotation_group]) + ) elif len(sorted_frame_segment_indices) == len(annotation_group): # Group segment by segment_index last_segment_id = 0 @@ -264,32 +362,34 @@ def _get_segment_frame_ranges( return frame_ranges else: raise ValueError( - 
f"Video annotations cannot partially have `segment_index` set") + f"Video annotations cannot partially have `segment_index` set" + ) @classmethod def _create_video_annotations( cls, label: Label ) -> Generator[Union[NDChecklistSubclass, NDRadioSubclass], None, None]: - video_annotations = defaultdict(list) for annot in label.annotations: if isinstance( - annot, - (VideoClassificationAnnotation, VideoObjectAnnotation)): - video_annotations[annot.feature_schema_id or - annot.name].append(annot) + annot, (VideoClassificationAnnotation, VideoObjectAnnotation) + ): + video_annotations[annot.feature_schema_id or annot.name].append( + annot + ) elif isinstance(annot, VideoMaskAnnotation): yield NDObject.from_common(annotation=annot, data=label.data) for annotation_group in video_annotations.values(): segment_frame_ranges = cls._get_segment_frame_ranges( - annotation_group) + annotation_group + ) if isinstance(annotation_group[0], VideoClassificationAnnotation): annotation = annotation_group[0] frames_data = [] for frames in segment_frame_ranges: - frames_data.append({'start': frames[0], 'end': frames[-1]}) - annotation.extra.update({'frames': frames_data}) + frames_data.append({"start": frames[0], "end": frames[-1]}) + annotation.extra.update({"frames": frames_data}) yield NDClassification.from_common(annotation, label.data) elif isinstance(annotation_group[0], VideoObjectAnnotation): @@ -297,7 +397,10 @@ def _create_video_annotations( for start_frame, end_frame in segment_frame_ranges: segment = [] for annotation in annotation_group: - if annotation.keyframe and start_frame <= annotation.frame <= end_frame: + if ( + annotation.keyframe + and start_frame <= annotation.frame <= end_frame + ): segment.append(annotation) segments.append(segment) yield NDObject.from_common(segments, label.data) @@ -305,10 +408,16 @@ def _create_video_annotations( @classmethod def _create_non_video_annotations(cls, label: Label): non_video_annotations = [ - annot for annot in label.annotations - if not isinstance(annot, (VideoClassificationAnnotation, - VideoObjectAnnotation, - VideoMaskAnnotation)) + annot + for annot in label.annotations + if not isinstance( + annot, + ( + VideoClassificationAnnotation, + VideoObjectAnnotation, + VideoMaskAnnotation, + ), + ) ] for annotation in non_video_annotations: if isinstance(annotation, ClassificationAnnotation): diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py index 9fd90544c..60d538b19 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py @@ -3,11 +3,17 @@ from labelbox.data.annotation_types.data import ImageData, TextData from labelbox.data.serialization.ndjson.base import DataRow, NDJsonBase from labelbox.data.annotation_types.metrics.scalar import ( - ScalarMetric, ScalarMetricAggregation, ScalarMetricValue, - ScalarMetricConfidenceValue) + ScalarMetric, + ScalarMetricAggregation, + ScalarMetricValue, + ScalarMetricConfidenceValue, +) from labelbox.data.annotation_types.metrics.confusion_matrix import ( - ConfusionMatrixAggregation, ConfusionMatrixMetric, - ConfusionMatrixMetricValue, ConfusionMatrixMetricConfidenceValue) + ConfusionMatrixAggregation, + ConfusionMatrixMetric, + ConfusionMatrixMetricValue, + ConfusionMatrixMetricConfidenceValue, +) from pydantic import ConfigDict, model_serializer from .base import _SubclassRegistryBase @@ -16,71 +22,82 @@ class BaseNDMetric(NDJsonBase): 
metric_value: float feature_name: Optional[str] = None subclass_name: Optional[str] = None - model_config = ConfigDict(use_enum_values = True) + model_config = ConfigDict(use_enum_values=True) - @model_serializer(mode = "wrap") + @model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) - for field in ['featureName', 'subclassName']: + for field in ["featureName", "subclassName"]: if field in res and res[field] is None: res.pop(field) return res class NDConfusionMatrixMetric(BaseNDMetric, _SubclassRegistryBase): - metric_value: Union[ConfusionMatrixMetricValue, - ConfusionMatrixMetricConfidenceValue] + metric_value: Union[ + ConfusionMatrixMetricValue, ConfusionMatrixMetricConfidenceValue + ] metric_name: str aggregation: ConfusionMatrixAggregation def to_common(self) -> ConfusionMatrixMetric: - return ConfusionMatrixMetric(value=self.metric_value, - metric_name=self.metric_name, - feature_name=self.feature_name, - subclass_name=self.subclass_name, - aggregation=self.aggregation, - extra={'uuid': self.uuid}) + return ConfusionMatrixMetric( + value=self.metric_value, + metric_name=self.metric_name, + feature_name=self.feature_name, + subclass_name=self.subclass_name, + aggregation=self.aggregation, + extra={"uuid": self.uuid}, + ) @classmethod def from_common( - cls, metric: ConfusionMatrixMetric, - data: Union[TextData, ImageData]) -> "NDConfusionMatrixMetric": - return cls(uuid=metric.extra.get('uuid'), - metric_value=metric.value, - metric_name=metric.metric_name, - feature_name=metric.feature_name, - subclass_name=metric.subclass_name, - aggregation=metric.aggregation, - data_row=DataRow(id=data.uid, global_key=data.global_key)) + cls, metric: ConfusionMatrixMetric, data: Union[TextData, ImageData] + ) -> "NDConfusionMatrixMetric": + return cls( + uuid=metric.extra.get("uuid"), + metric_value=metric.value, + metric_name=metric.metric_name, + feature_name=metric.feature_name, + subclass_name=metric.subclass_name, + aggregation=metric.aggregation, + data_row=DataRow(id=data.uid, global_key=data.global_key), + ) class NDScalarMetric(BaseNDMetric, _SubclassRegistryBase): metric_value: Union[ScalarMetricValue, ScalarMetricConfidenceValue] metric_name: Optional[str] = None - aggregation: Optional[ScalarMetricAggregation] = ScalarMetricAggregation.ARITHMETIC_MEAN + aggregation: Optional[ScalarMetricAggregation] = ( + ScalarMetricAggregation.ARITHMETIC_MEAN + ) def to_common(self) -> ScalarMetric: - return ScalarMetric(value=self.metric_value, - metric_name=self.metric_name, - feature_name=self.feature_name, - subclass_name=self.subclass_name, - aggregation=self.aggregation, - extra={'uuid': self.uuid}) + return ScalarMetric( + value=self.metric_value, + metric_name=self.metric_name, + feature_name=self.feature_name, + subclass_name=self.subclass_name, + aggregation=self.aggregation, + extra={"uuid": self.uuid}, + ) @classmethod - def from_common(cls, metric: ScalarMetric, - data: Union[TextData, ImageData]) -> "NDScalarMetric": - return cls(uuid=metric.extra.get('uuid'), - metric_value=metric.value, - metric_name=metric.metric_name, - feature_name=metric.feature_name, - subclass_name=metric.subclass_name, - aggregation=metric.aggregation.value, - data_row=DataRow(id=data.uid, global_key=data.global_key)) + def from_common( + cls, metric: ScalarMetric, data: Union[TextData, ImageData] + ) -> "NDScalarMetric": + return cls( + uuid=metric.extra.get("uuid"), + metric_value=metric.value, + metric_name=metric.metric_name, + feature_name=metric.feature_name, + 
subclass_name=metric.subclass_name, + aggregation=metric.aggregation.value, + data_row=DataRow(id=data.uid, global_key=data.global_key), + ) class NDMetricAnnotation: - @classmethod def to_common( cls, annotation: Union[NDScalarMetric, NDConfusionMatrixMetric] @@ -89,16 +106,16 @@ def to_common( @classmethod def from_common( - cls, annotation: Union[ScalarMetric, - ConfusionMatrixMetric], data: Union[TextData, - ImageData] + cls, + annotation: Union[ScalarMetric, ConfusionMatrixMetric], + data: Union[TextData, ImageData], ) -> Union[NDScalarMetric, NDConfusionMatrixMetric]: obj = cls.lookup_object(annotation) return obj.from_common(annotation, data) @staticmethod def lookup_object( - annotation: Union[ScalarMetric, ConfusionMatrixMetric] + annotation: Union[ScalarMetric, ConfusionMatrixMetric], ) -> Union[Type[NDScalarMetric], Type[NDConfusionMatrixMetric]]: result = { ScalarMetric: NDScalarMetric, @@ -106,5 +123,6 @@ def lookup_object( }.get(type(annotation)) if result is None: raise TypeError( - f"Unable to convert object to MAL format. `{type(annotation)}`") + f"Unable to convert object to MAL format. `{type(annotation)}`" + ) return result diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py index 7b1908b76..4cb797f38 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py @@ -4,17 +4,24 @@ from .base import _SubclassRegistryBase, DataRow, NDAnnotation from ...annotation_types.types import Cuid -from ...annotation_types.mmc import MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask, MessageEvaluationTaskAnnotation +from ...annotation_types.mmc import ( + MessageSingleSelectionTask, + MessageMultiSelectionTask, + MessageRankingTask, + MessageEvaluationTaskAnnotation, +) class MessageTaskData(_CamelCaseMixin): format: str - data: Union[MessageSingleSelectionTask, MessageMultiSelectionTask, - MessageRankingTask] + data: Union[ + MessageSingleSelectionTask, + MessageMultiSelectionTask, + MessageRankingTask, + ] class NDMessageTask(NDAnnotation, _SubclassRegistryBase): - message_evaluation_task: MessageTaskData def to_common(self) -> MessageEvaluationTaskAnnotation: @@ -27,13 +34,16 @@ def to_common(self) -> MessageEvaluationTaskAnnotation: @classmethod def from_common( - cls, - annotation: MessageEvaluationTaskAnnotation, - data: Any #Union[ImageData, TextData], + cls, + annotation: MessageEvaluationTaskAnnotation, + data: Any, # Union[ImageData, TextData], ) -> "NDMessageTask": - return cls(uuid=str(annotation._uuid), - name=annotation.name, - schema_id=annotation.feature_schema_id, - data_row=DataRow(id=data.uid, global_key=data.global_key), - message_evaluation_task=MessageTaskData( - format=annotation.value.format, data=annotation.value)) + return cls( + uuid=str(annotation._uuid), + name=annotation.name, + schema_id=annotation.feature_schema_id, + data_row=DataRow(id=data.uid, global_key=data.global_key), + message_evaluation_task=MessageTaskData( + format=annotation.value.format, data=annotation.value + ), + ) diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py index 2b32f1c2b..79e9b4adf 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py @@ -2,9 +2,19 @@ from typing import Any, Dict, List, Tuple, Union, Optional 
import base64 -from labelbox.data.annotation_types.ner.conversation_entity import ConversationEntity -from labelbox.data.annotation_types.video import VideoObjectAnnotation, DICOMObjectAnnotation -from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin, CustomMetric, CustomMetricsNotSupportedMixin +from labelbox.data.annotation_types.ner.conversation_entity import ( + ConversationEntity, +) +from labelbox.data.annotation_types.video import ( + VideoObjectAnnotation, + DICOMObjectAnnotation, +) +from labelbox.data.mixins import ( + ConfidenceMixin, + CustomMetricsMixin, + CustomMetric, + CustomMetricsNotSupportedMixin, +) import numpy as np from PIL import Image @@ -13,12 +23,35 @@ from labelbox.data.annotation_types.data.video import VideoData from ...annotation_types.data import ImageData, TextData, MaskData -from ...annotation_types.ner import DocumentEntity, DocumentTextSelection, TextEntity +from ...annotation_types.ner import ( + DocumentEntity, + DocumentTextSelection, + TextEntity, +) from ...annotation_types.types import Cuid -from ...annotation_types.geometry import DocumentRectangle, Rectangle, Polygon, Line, Point, Mask -from ...annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation -from ...annotation_types.video import VideoMaskAnnotation, DICOMMaskAnnotation, MaskFrame, MaskInstance -from .classification import NDClassification, NDSubclassification, NDSubclassificationType +from ...annotation_types.geometry import ( + DocumentRectangle, + Rectangle, + Polygon, + Line, + Point, + Mask, +) +from ...annotation_types.annotation import ( + ClassificationAnnotation, + ObjectAnnotation, +) +from ...annotation_types.video import ( + VideoMaskAnnotation, + DICOMMaskAnnotation, + MaskFrame, + MaskInstance, +) +from .classification import ( + NDClassification, + NDSubclassification, + NDSubclassificationType, +) from .base import DataRow, NDAnnotation, NDJsonBase, _SubclassRegistryBase from pydantic import BaseModel @@ -48,7 +81,9 @@ class Bbox(BaseModel): width: float -class NDPoint(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase): +class NDPoint( + NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase +): point: _Point def to_common(self) -> Point: @@ -56,46 +91,48 @@ def to_common(self) -> Point: @classmethod def from_common( - cls, - uuid: str, - point: Point, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None) -> "NDPoint": - return cls(point={ - 'x': point.x, - 'y': point.y - }, - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) + cls, + uuid: str, + point: Point, + classifications: List[ClassificationAnnotation], + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[ImageData, TextData], + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, + ) -> "NDPoint": + return cls( + point={"x": point.x, "y": point.y}, + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + confidence=confidence, + custom_metrics=custom_metrics, + ) class NDFramePoint(VideoSupported, _SubclassRegistryBase): point: _Point 
classifications: List[NDSubclassificationType] = [] - def to_common(self, name: str, feature_schema_id: Cuid, - segment_index: int) -> VideoObjectAnnotation: - return VideoObjectAnnotation(frame=self.frame, - segment_index=segment_index, - keyframe=True, - name=name, - feature_schema_id=feature_schema_id, - value=Point(x=self.point.x, - y=self.point.y), - classifications=[ - NDSubclassification.to_common(annot) - for annot in self.classifications - ]) + def to_common( + self, name: str, feature_schema_id: Cuid, segment_index: int + ) -> VideoObjectAnnotation: + return VideoObjectAnnotation( + frame=self.frame, + segment_index=segment_index, + keyframe=True, + name=name, + feature_schema_id=feature_schema_id, + value=Point(x=self.point.x, y=self.point.y), + classifications=[ + NDSubclassification.to_common(annot) + for annot in self.classifications + ], + ) @classmethod def from_common( @@ -104,12 +141,16 @@ def from_common( point: Point, classifications: List[NDSubclassificationType], ): - return cls(frame=frame, - point=_Point(x=point.x, y=point.y), - classifications=classifications) + return cls( + frame=frame, + point=_Point(x=point.x, y=point.y), + classifications=classifications, + ) -class NDLine(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase): +class NDLine( + NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase +): line: List[_Point] def to_common(self) -> Line: @@ -117,35 +158,36 @@ def to_common(self) -> Line: @classmethod def from_common( - cls, - uuid: str, - line: Line, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None) -> "NDLine": - return cls(line=[{ - 'x': pt.x, - 'y': pt.y - } for pt in line.points], - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) + cls, + uuid: str, + line: Line, + classifications: List[ClassificationAnnotation], + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[ImageData, TextData], + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, + ) -> "NDLine": + return cls( + line=[{"x": pt.x, "y": pt.y} for pt in line.points], + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + confidence=confidence, + custom_metrics=custom_metrics, + ) class NDFrameLine(VideoSupported, _SubclassRegistryBase): line: List[_Point] classifications: List[NDSubclassificationType] = [] - def to_common(self, name: str, feature_schema_id: Cuid, - segment_index: int) -> VideoObjectAnnotation: + def to_common( + self, name: str, feature_schema_id: Cuid, segment_index: int + ) -> VideoObjectAnnotation: return VideoObjectAnnotation( frame=self.frame, segment_index=segment_index, @@ -156,7 +198,8 @@ def to_common(self, name: str, feature_schema_id: Cuid, classifications=[ NDSubclassification.to_common(annot) for annot in self.classifications - ]) + ], + ) @classmethod def from_common( @@ -165,18 +208,21 @@ def from_common( line: Line, classifications: List[NDSubclassificationType], ): - return cls(frame=frame, - line=[{ - 'x': pt.x, - 'y': pt.y - } for pt in line.points], - classifications=classifications) + return cls( + 
frame=frame, + line=[{"x": pt.x, "y": pt.y} for pt in line.points], + classifications=classifications, + ) class NDDicomLine(NDFrameLine, _SubclassRegistryBase): - - def to_common(self, name: str, feature_schema_id: Cuid, segment_index: int, - group_key: str) -> DICOMObjectAnnotation: + def to_common( + self, + name: str, + feature_schema_id: Cuid, + segment_index: int, + group_key: str, + ) -> DICOMObjectAnnotation: return DICOMObjectAnnotation( frame=self.frame, segment_index=segment_index, @@ -184,10 +230,13 @@ def to_common(self, name: str, feature_schema_id: Cuid, segment_index: int, name=name, feature_schema_id=feature_schema_id, value=Line(points=[Point(x=pt.x, y=pt.y) for pt in self.line]), - group_key=group_key) + group_key=group_key, + ) -class NDPolygon(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase): +class NDPolygon( + NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase +): polygon: List[_Point] def to_common(self) -> Polygon: @@ -195,63 +244,73 @@ def to_common(self) -> Polygon: @classmethod def from_common( - cls, - uuid: str, - polygon: Polygon, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None) -> "NDPolygon": - return cls(polygon=[{ - 'x': pt.x, - 'y': pt.y - } for pt in polygon.points], - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) - - -class NDRectangle(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase): + cls, + uuid: str, + polygon: Polygon, + classifications: List[ClassificationAnnotation], + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[ImageData, TextData], + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, + ) -> "NDPolygon": + return cls( + polygon=[{"x": pt.x, "y": pt.y} for pt in polygon.points], + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + confidence=confidence, + custom_metrics=custom_metrics, + ) + + +class NDRectangle( + NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase +): bbox: Bbox def to_common(self) -> Rectangle: - return Rectangle(start=Point(x=self.bbox.left, y=self.bbox.top), - end=Point(x=self.bbox.left + self.bbox.width, - y=self.bbox.top + self.bbox.height)) + return Rectangle( + start=Point(x=self.bbox.left, y=self.bbox.top), + end=Point( + x=self.bbox.left + self.bbox.width, + y=self.bbox.top + self.bbox.height, + ), + ) @classmethod def from_common( - cls, - uuid: str, - rectangle: Rectangle, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None + cls, + uuid: str, + rectangle: Rectangle, + classifications: List[ClassificationAnnotation], + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[ImageData, TextData], + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDRectangle": - return cls(bbox=Bbox(top=min(rectangle.start.y, rectangle.end.y), - 
left=min(rectangle.start.x, rectangle.end.x), - height=abs(rectangle.end.y - rectangle.start.y), - width=abs(rectangle.end.x - rectangle.start.x)), - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - page=extra.get('page'), - unit=extra.get('unit'), - confidence=confidence, - custom_metrics=custom_metrics) + return cls( + bbox=Bbox( + top=min(rectangle.start.y, rectangle.end.y), + left=min(rectangle.start.x, rectangle.end.x), + height=abs(rectangle.end.y - rectangle.start.y), + width=abs(rectangle.end.x - rectangle.start.x), + ), + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + page=extra.get("page"), + unit=extra.get("unit"), + confidence=confidence, + custom_metrics=custom_metrics, + ) class NDDocumentRectangle(NDRectangle, _SubclassRegistryBase): @@ -259,59 +318,73 @@ class NDDocumentRectangle(NDRectangle, _SubclassRegistryBase): unit: str def to_common(self) -> DocumentRectangle: - return DocumentRectangle(start=Point(x=self.bbox.left, y=self.bbox.top), - end=Point(x=self.bbox.left + self.bbox.width, - y=self.bbox.top + self.bbox.height), - page=self.page, - unit=self.unit) + return DocumentRectangle( + start=Point(x=self.bbox.left, y=self.bbox.top), + end=Point( + x=self.bbox.left + self.bbox.width, + y=self.bbox.top + self.bbox.height, + ), + page=self.page, + unit=self.unit, + ) @classmethod def from_common( - cls, - uuid: str, - rectangle: DocumentRectangle, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None + cls, + uuid: str, + rectangle: DocumentRectangle, + classifications: List[ClassificationAnnotation], + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[ImageData, TextData], + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDRectangle": - return cls(bbox=Bbox(top=min(rectangle.start.y, rectangle.end.y), - left=min(rectangle.start.x, rectangle.end.x), - height=abs(rectangle.end.y - rectangle.start.y), - width=abs(rectangle.end.x - rectangle.start.x)), - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - page=rectangle.page, - unit=rectangle.unit.value, - confidence=confidence, - custom_metrics=custom_metrics) + return cls( + bbox=Bbox( + top=min(rectangle.start.y, rectangle.end.y), + left=min(rectangle.start.x, rectangle.end.x), + height=abs(rectangle.end.y - rectangle.start.y), + width=abs(rectangle.end.x - rectangle.start.x), + ), + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + page=rectangle.page, + unit=rectangle.unit.value, + confidence=confidence, + custom_metrics=custom_metrics, + ) class NDFrameRectangle(VideoSupported, _SubclassRegistryBase): bbox: Bbox classifications: List[NDSubclassificationType] = [] - def to_common(self, name: str, feature_schema_id: Cuid, - segment_index: int) -> VideoObjectAnnotation: + def to_common( + self, name: str, feature_schema_id: Cuid, segment_index: int + ) -> VideoObjectAnnotation: return VideoObjectAnnotation( frame=self.frame, segment_index=segment_index, 
keyframe=True, name=name, feature_schema_id=feature_schema_id, - value=Rectangle(start=Point(x=self.bbox.left, y=self.bbox.top), - end=Point(x=self.bbox.left + self.bbox.width, - y=self.bbox.top + self.bbox.height)), + value=Rectangle( + start=Point(x=self.bbox.left, y=self.bbox.top), + end=Point( + x=self.bbox.left + self.bbox.width, + y=self.bbox.top + self.bbox.height, + ), + ), classifications=[ NDSubclassification.to_common(annot) for annot in self.classifications - ]) + ], + ) @classmethod def from_common( @@ -320,12 +393,16 @@ def from_common( rectangle: Rectangle, classifications: List[NDSubclassificationType], ): - return cls(frame=frame, - bbox=Bbox(top=min(rectangle.start.y, rectangle.end.y), - left=min(rectangle.start.x, rectangle.end.x), - height=abs(rectangle.end.y - rectangle.start.y), - width=abs(rectangle.end.x - rectangle.start.x)), - classifications=classifications) + return cls( + frame=frame, + bbox=Bbox( + top=min(rectangle.start.y, rectangle.end.y), + left=min(rectangle.start.x, rectangle.end.x), + height=abs(rectangle.end.y - rectangle.start.y), + width=abs(rectangle.end.x - rectangle.start.x), + ), + classifications=classifications, + ) class NDSegment(BaseModel): @@ -343,19 +420,25 @@ def lookup_segment_object_type(segment: List) -> "NDFrameObjectType": return result @staticmethod - def segment_with_uuid(keyframe: Union[NDFrameRectangle, NDFramePoint, - NDFrameLine], uuid: str): + def segment_with_uuid( + keyframe: Union[NDFrameRectangle, NDFramePoint, NDFrameLine], uuid: str + ): keyframe._uuid = uuid - keyframe.extra = {'uuid': uuid} + keyframe.extra = {"uuid": uuid} return keyframe - def to_common(self, name: str, feature_schema_id: Cuid, uuid: str, - segment_index: int): + def to_common( + self, name: str, feature_schema_id: Cuid, uuid: str, segment_index: int + ): return [ self.segment_with_uuid( - keyframe.to_common(name=name, - feature_schema_id=feature_schema_id, - segment_index=segment_index), uuid) + keyframe.to_common( + name=name, + feature_schema_id=feature_schema_id, + segment_index=segment_index, + ), + uuid, + ) for keyframe in self.keyframes ] @@ -363,14 +446,19 @@ def to_common(self, name: str, feature_schema_id: Cuid, uuid: str, def from_common(cls, segment): nd_frame_object_type = cls.lookup_segment_object_type(segment) - return cls(keyframes=[ - nd_frame_object_type.from_common( - object_annotation.frame, object_annotation.value, [ - NDSubclassification.from_common(annot) - for annot in object_annotation.classifications - ]) - for object_annotation in segment - ]) + return cls( + keyframes=[ + nd_frame_object_type.from_common( + object_annotation.frame, + object_annotation.value, + [ + NDSubclassification.from_common(annot) + for annot in object_annotation.classifications + ], + ) + for object_annotation in segment + ] + ) class NDDicomSegment(NDSegment): @@ -384,16 +472,26 @@ def lookup_segment_object_type(segment: List) -> "NDDicomObjectType": if segment_class == Line: return NDDicomLine else: - raise ValueError('DICOM segments only support Line objects') + raise ValueError("DICOM segments only support Line objects") - def to_common(self, name: str, feature_schema_id: Cuid, uuid: str, - segment_index: int, group_key: str): + def to_common( + self, + name: str, + feature_schema_id: Cuid, + uuid: str, + segment_index: int, + group_key: str, + ): return [ self.segment_with_uuid( - keyframe.to_common(name=name, - feature_schema_id=feature_schema_id, - segment_index=segment_index, - group_key=group_key), uuid) + keyframe.to_common( + name=name, + 
feature_schema_id=feature_schema_id, + segment_index=segment_index, + group_key=group_key, + ), + uuid, + ) for keyframe in self.keyframes ] @@ -405,24 +503,33 @@ def to_common(self, name: str, feature_schema_id: Cuid): result = [] for idx, segment in enumerate(self.segments): result.extend( - segment.to_common(name=name, - feature_schema_id=feature_schema_id, - segment_index=idx, - uuid=self.uuid)) + segment.to_common( + name=name, + feature_schema_id=feature_schema_id, + segment_index=idx, + uuid=self.uuid, + ) + ) return result @classmethod - def from_common(cls, segments: List[VideoObjectAnnotation], data: VideoData, - name: str, feature_schema_id: Cuid, - extra: Dict[str, Any]) -> "NDSegments": - + def from_common( + cls, + segments: List[VideoObjectAnnotation], + data: VideoData, + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + ) -> "NDSegments": segments = [NDSegment.from_common(segment) for segment in segments] - return cls(segments=segments, - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=extra.get('uuid')) + return cls( + segments=segments, + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=extra.get("uuid"), + ) class NDDicomSegments(NDBaseObject, DicomSupported, _SubclassRegistryBase): @@ -432,26 +539,36 @@ def to_common(self, name: str, feature_schema_id: Cuid): result = [] for idx, segment in enumerate(self.segments): result.extend( - segment.to_common(name=name, - feature_schema_id=feature_schema_id, - segment_index=idx, - uuid=self.uuid, - group_key=self.group_key)) + segment.to_common( + name=name, + feature_schema_id=feature_schema_id, + segment_index=idx, + uuid=self.uuid, + group_key=self.group_key, + ) + ) return result @classmethod - def from_common(cls, segments: List[DICOMObjectAnnotation], data: VideoData, - name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - group_key: str) -> "NDDicomSegments": - + def from_common( + cls, + segments: List[DICOMObjectAnnotation], + data: VideoData, + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + group_key: str, + ) -> "NDDicomSegments": segments = [NDDicomSegment.from_common(segment) for segment in segments] - return cls(segments=segments, - dataRow=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=extra.get('uuid'), - group_key=group_key) + return cls( + segments=segments, + dataRow=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=extra.get("uuid"), + group_key=group_key, + ) class _URIMask(BaseModel): @@ -463,53 +580,61 @@ class _PNGMask(BaseModel): png: str -class NDMask(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase): +class NDMask( + NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase +): mask: Union[_URIMask, _PNGMask] def to_common(self) -> Mask: if isinstance(self.mask, _URIMask): - return Mask(mask=MaskData(url=self.mask.instanceURI), - color=self.mask.colorRGB) + return Mask( + mask=MaskData(url=self.mask.instanceURI), + color=self.mask.colorRGB, + ) else: - encoded_image_bytes = self.mask.png.encode('utf-8') + encoded_image_bytes = self.mask.png.encode("utf-8") image_bytes = base64.b64decode(encoded_image_bytes) image = np.array(Image.open(BytesIO(image_bytes))) if np.max(image) > 1: raise ValueError( - f"Expected binary mask. Found max value of {np.max(image)}") + f"Expected binary mask. 
Found max value of {np.max(image)}" + ) # Color is 1,1,1 because it is a binary array and we are just stacking it into 3 channels return Mask(mask=MaskData.from_2D_arr(image), color=(1, 1, 1)) @classmethod def from_common( - cls, - uuid: str, - mask: Mask, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None) -> "NDMask": - + cls, + uuid: str, + mask: Mask, + classifications: List[ClassificationAnnotation], + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[ImageData, TextData], + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, + ) -> "NDMask": if mask.mask.url is not None: lbv1_mask = _URIMask(instanceURI=mask.mask.url, colorRGB=mask.color) else: binary = np.all(mask.mask.value == mask.color, axis=-1) im_bytes = BytesIO() - Image.fromarray(binary, 'L').save(im_bytes, format="PNG") + Image.fromarray(binary, "L").save(im_bytes, format="PNG") lbv1_mask = _PNGMask( - png=base64.b64encode(im_bytes.getvalue()).decode('utf-8')) + png=base64.b64encode(im_bytes.getvalue()).decode("utf-8") + ) - return cls(mask=lbv1_mask, - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) + return cls( + mask=lbv1_mask, + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + confidence=confidence, + custom_metrics=custom_metrics, + ) class NDVideoMasksFramesInstances(BaseModel): @@ -517,14 +642,20 @@ class NDVideoMasksFramesInstances(BaseModel): instances: List[MaskInstance] -class NDVideoMasks(NDJsonBase, ConfidenceMixin, CustomMetricsNotSupportedMixin, _SubclassRegistryBase): +class NDVideoMasks( + NDJsonBase, + ConfidenceMixin, + CustomMetricsNotSupportedMixin, + _SubclassRegistryBase, +): masks: NDVideoMasksFramesInstances def to_common(self) -> VideoMaskAnnotation: for mask_frame in self.masks.frames: if mask_frame.im_bytes: mask_frame.im_bytes = base64.b64decode( - mask_frame.im_bytes.encode('utf-8')) + mask_frame.im_bytes.encode("utf-8") + ) return VideoMaskAnnotation( frames=self.masks.frames, @@ -536,17 +667,18 @@ def from_common(cls, annotation, data): for mask_frame in annotation.frames: if mask_frame.im_bytes: mask_frame.im_bytes = base64.b64encode( - mask_frame.im_bytes).decode('utf-8') + mask_frame.im_bytes + ).decode("utf-8") return cls( data_row=DataRow(id=data.uid, global_key=data.global_key), - masks=NDVideoMasksFramesInstances(frames=annotation.frames, - instances=annotation.instances), + masks=NDVideoMasksFramesInstances( + frames=annotation.frames, instances=annotation.instances + ), ) class NDDicomMasks(NDVideoMasks, DicomSupported, _SubclassRegistryBase): - def to_common(self) -> DICOMMaskAnnotation: return DICOMMaskAnnotation( frames=self.masks.frames, @@ -558,8 +690,9 @@ def to_common(self) -> DICOMMaskAnnotation: def from_common(cls, annotation, data): return cls( data_row=DataRow(id=data.uid, global_key=data.global_key), - masks=NDVideoMasksFramesInstances(frames=annotation.frames, - instances=annotation.instances), + masks=NDVideoMasksFramesInstances( + frames=annotation.frames, instances=annotation.instances + ), group_key=annotation.group_key.value, ) @@ -569,7 +702,9 @@ 
class Location(BaseModel): end: int -class NDTextEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase): +class NDTextEntity( + NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase +): location: Location def to_common(self) -> TextEntity: @@ -577,37 +712,42 @@ def to_common(self) -> TextEntity: @classmethod def from_common( - cls, - uuid: str, - text_entity: TextEntity, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None + cls, + uuid: str, + text_entity: TextEntity, + classifications: List[ClassificationAnnotation], + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[ImageData, TextData], + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDTextEntity": - return cls(location=Location( - start=text_entity.start, - end=text_entity.end, - ), - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) - - -class NDDocumentEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase): + return cls( + location=Location( + start=text_entity.start, + end=text_entity.end, + ), + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + confidence=confidence, + custom_metrics=custom_metrics, + ) + + +class NDDocumentEntity( + NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase +): name: str text_selections: List[DocumentTextSelection] def to_common(self) -> DocumentEntity: - return DocumentEntity(name=self.name, - text_selections=self.text_selections) + return DocumentEntity( + name=self.name, text_selections=self.text_selections + ) @classmethod def from_common( @@ -620,26 +760,29 @@ def from_common( extra: Dict[str, Any], data: Union[ImageData, TextData], confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None + custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDDocumentEntity": - - return cls(text_selections=document_entity.text_selections, - dataRow=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) + return cls( + text_selections=document_entity.text_selections, + dataRow=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + confidence=confidence, + custom_metrics=custom_metrics, + ) class NDConversationEntity(NDTextEntity, _SubclassRegistryBase): message_id: str def to_common(self) -> ConversationEntity: - return ConversationEntity(start=self.location.start, - end=self.location.end, - message_id=self.message_id) + return ConversationEntity( + start=self.location.start, + end=self.location.end, + message_id=self.message_id, + ) @classmethod def from_common( @@ -652,22 +795,24 @@ def from_common( extra: Dict[str, Any], data: Union[ImageData, TextData], confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None + custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDConversationEntity": - 
return cls(location=Location(start=conversation_entity.start, - end=conversation_entity.end), - message_id=conversation_entity.message_id, - dataRow=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) + return cls( + location=Location( + start=conversation_entity.start, end=conversation_entity.end + ), + message_id=conversation_entity.message_id, + dataRow=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + confidence=confidence, + custom_metrics=custom_metrics, + ) class NDObject: - @staticmethod def to_common(annotation: "NDObjectType") -> ObjectAnnotation: common_annotation = annotation.to_common() @@ -675,49 +820,66 @@ def to_common(annotation: "NDObjectType") -> ObjectAnnotation: NDSubclassification.to_common(annot) for annot in annotation.classifications ] - confidence = annotation.confidence if hasattr(annotation, - 'confidence') else None - - custom_metrics = annotation.custom_metrics if hasattr( - annotation, 'custom_metrics') else None - return ObjectAnnotation(value=common_annotation, - name=annotation.name, - feature_schema_id=annotation.schema_id, - classifications=classifications, - extra={ - 'uuid': annotation.uuid, - 'page': annotation.page, - 'unit': annotation.unit - }, - confidence=confidence, - custom_metrics=custom_metrics) + confidence = ( + annotation.confidence if hasattr(annotation, "confidence") else None + ) + + custom_metrics = ( + annotation.custom_metrics + if hasattr(annotation, "custom_metrics") + else None + ) + return ObjectAnnotation( + value=common_annotation, + name=annotation.name, + feature_schema_id=annotation.schema_id, + classifications=classifications, + extra={ + "uuid": annotation.uuid, + "page": annotation.page, + "unit": annotation.unit, + }, + confidence=confidence, + custom_metrics=custom_metrics, + ) @classmethod def from_common( cls, - annotation: Union[ObjectAnnotation, List[List[VideoObjectAnnotation]], - VideoMaskAnnotation], data: Union[ImageData, TextData] - ) -> Union[NDLine, NDPoint, NDPolygon, NDDocumentRectangle, NDRectangle, - NDMask, NDTextEntity]: + annotation: Union[ + ObjectAnnotation, + List[List[VideoObjectAnnotation]], + VideoMaskAnnotation, + ], + data: Union[ImageData, TextData], + ) -> Union[ + NDLine, + NDPoint, + NDPolygon, + NDDocumentRectangle, + NDRectangle, + NDMask, + NDTextEntity, + ]: obj = cls.lookup_object(annotation) # if it is video segments - if (obj == NDSegments or obj == NDDicomSegments): - + if obj == NDSegments or obj == NDDicomSegments: first_video_annotation = annotation[0][0] args = dict( segments=annotation, data=data, name=first_video_annotation.name, feature_schema_id=first_video_annotation.feature_schema_id, - extra=first_video_annotation.extra) + extra=first_video_annotation.extra, + ) if isinstance(first_video_annotation, DICOMObjectAnnotation): group_key = first_video_annotation.group_key.value args.update(dict(group_key=group_key)) return obj.from_common(**args) - elif (obj == NDVideoMasks or obj == NDDicomMasks): + elif obj == NDVideoMasks or obj == NDDicomMasks: return obj.from_common(annotation, data) subclasses = [ @@ -725,21 +887,27 @@ def from_common( for annot in annotation.classifications ] optional_kwargs = {} - if (annotation.confidence): - optional_kwargs['confidence'] = annotation.confidence - - if (annotation.custom_metrics): - 
optional_kwargs['custom_metrics'] = annotation.custom_metrics - - return obj.from_common(str(annotation._uuid), annotation.value, - subclasses, annotation.name, - annotation.feature_schema_id, annotation.extra, - data, **optional_kwargs) + if annotation.confidence: + optional_kwargs["confidence"] = annotation.confidence + + if annotation.custom_metrics: + optional_kwargs["custom_metrics"] = annotation.custom_metrics + + return obj.from_common( + str(annotation._uuid), + annotation.value, + subclasses, + annotation.name, + annotation.feature_schema_id, + annotation.extra, + data, + **optional_kwargs, + ) @staticmethod def lookup_object( - annotation: Union[ObjectAnnotation, List]) -> "NDObjectType": - + annotation: Union[ObjectAnnotation, List], + ) -> "NDObjectType": if isinstance(annotation, DICOMMaskAnnotation): result = NDDicomMasks elif isinstance(annotation, VideoMaskAnnotation): @@ -772,9 +940,18 @@ def lookup_object( ) return result + NDEntityType = Union[NDConversationEntity, NDTextEntity] -NDObjectType = Union[NDLine, NDPolygon, NDPoint, NDDocumentRectangle, - NDRectangle, NDMask, NDEntityType, NDDocumentEntity] +NDObjectType = Union[ + NDLine, + NDPolygon, + NDPoint, + NDDocumentRectangle, + NDRectangle, + NDMask, + NDEntityType, + NDDocumentEntity, +] NDFrameObjectType = NDFrameRectangle, NDFramePoint, NDFrameLine NDDicomObjectType = NDDicomLine diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py index 1cdb23b76..fbea7e477 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py @@ -20,25 +20,36 @@ class NDRelationship(NDAnnotation, _SubclassRegistryBase): relationship: _Relationship @staticmethod - def to_common(annotation: "NDRelationship", source: SUPPORTED_ANNOTATIONS, - target: SUPPORTED_ANNOTATIONS) -> RelationshipAnnotation: - return RelationshipAnnotation(name=annotation.name, - value=Relationship( - source=source, - target=target, - type=Relationship.Type( - annotation.relationship.type)), - extra={'uuid': annotation.uuid}, - feature_schema_id=annotation.schema_id) + def to_common( + annotation: "NDRelationship", + source: SUPPORTED_ANNOTATIONS, + target: SUPPORTED_ANNOTATIONS, + ) -> RelationshipAnnotation: + return RelationshipAnnotation( + name=annotation.name, + value=Relationship( + source=source, + target=target, + type=Relationship.Type(annotation.relationship.type), + ), + extra={"uuid": annotation.uuid}, + feature_schema_id=annotation.schema_id, + ) @classmethod - def from_common(cls, annotation: RelationshipAnnotation, - data: Union[ImageData, TextData]) -> "NDRelationship": + def from_common( + cls, + annotation: RelationshipAnnotation, + data: Union[ImageData, TextData], + ) -> "NDRelationship": relationship = annotation.value - return cls(uuid=str(annotation._uuid), - name=annotation.name, - dataRow=DataRow(id=data.uid, global_key=data.global_key), - relationship=_Relationship( - source=str(relationship.source._uuid), - target=str(relationship.target._uuid), - type=relationship.type.value)) + return cls( + uuid=str(annotation._uuid), + name=annotation.name, + dataRow=DataRow(id=data.uid, global_key=data.global_key), + relationship=_Relationship( + source=str(relationship.source._uuid), + target=str(relationship.target._uuid), + type=relationship.type.value, + ), + ) diff --git a/libs/labelbox/src/labelbox/exceptions.py 
b/libs/labelbox/src/labelbox/exceptions.py index 048ca0757..34cfeaf4d 100644 --- a/libs/labelbox/src/labelbox/exceptions.py +++ b/libs/labelbox/src/labelbox/exceptions.py @@ -21,16 +21,18 @@ def __str__(self): class AuthenticationError(LabelboxError): """Raised when an API key fails authentication.""" + pass class AuthorizationError(LabelboxError): """Raised when a user is unauthorized to perform the given request.""" + pass class ResourceNotFoundError(LabelboxError): - """Exception raised when a given resource is not found. """ + """Exception raised when a given resource is not found.""" def __init__(self, db_object_type=None, params=None, message=None): """Constructor for the ResourceNotFoundException class. @@ -43,14 +45,17 @@ def __init__(self, db_object_type=None, params=None, message=None): if message is not None: super().__init__(message) else: - super().__init__("Resource '%s' not found for params: %r" % - (db_object_type.type_name(), params)) + super().__init__( + "Resource '%s' not found for params: %r" + % (db_object_type.type_name(), params) + ) self.db_object_type = db_object_type self.params = params class ResourceConflict(LabelboxError): - """Exception raised when a given resource conflicts with another. """ + """Exception raised when a given resource conflicts with another.""" + pass @@ -58,6 +63,7 @@ class ValidationFailedError(LabelboxError): """Exception raised for when a GraphQL query fails validation (query cost, etc.) E.g. a query that is too expensive, or depth is too deep. """ + pass @@ -68,25 +74,29 @@ class InternalServerError(LabelboxError): TODO: these errors need better messages from platform """ + pass class InvalidQueryError(LabelboxError): - """ Indicates a malconstructed or unsupported query (either by GraphQL in + """Indicates a malconstructed or unsupported query (either by GraphQL in general or by Labelbox specifically). This can be the result of either client - or server side query validation. """ + or server side query validation.""" + pass class UnprocessableEntityError(LabelboxError): - """ Indicates that a resource could not be created in the server side + """Indicates that a resource could not be created in the server side due to a validation or transaction error""" + pass class ResourceCreationError(LabelboxError): - """ Indicates that a resource could not be created in the server side + """Indicates that a resource could not be created in the server side due to a validation or transaction error""" + pass @@ -100,33 +110,39 @@ def __init__(self, cause): class TimeoutError(LabelboxError): """Raised when a request times-out.""" + pass class InvalidAttributeError(LabelboxError): - """ Raised when a field (name or Field instance) is not valid or found - for a specific DB object type. """ + """Raised when a field (name or Field instance) is not valid or found + for a specific DB object type.""" def __init__(self, db_object_type, field): - super().__init__("Field(s) '%r' not valid on DB type '%s'" % - (field, db_object_type.type_name())) + super().__init__( + "Field(s) '%r' not valid on DB type '%s'" + % (field, db_object_type.type_name()) + ) self.db_object_type = db_object_type self.field = field class ApiLimitError(LabelboxError): - """ Raised when the user performs too many requests in a short period - of time. 
""" + """Raised when the user performs too many requests in a short period + of time.""" + pass class MalformedQueryException(Exception): - """ Raised when the user submits a malformed query.""" + """Raised when the user submits a malformed query.""" + pass class UuidError(LabelboxError): - """ Raised when there are repeat Uuid's in bulk import request.""" + """Raised when there are repeat Uuid's in bulk import request.""" + pass @@ -136,16 +152,19 @@ class InconsistentOntologyException(Exception): class MALValidationError(LabelboxError): """Raised when user input is invalid for MAL imports.""" + pass class OperationNotAllowedException(Exception): """Raised when user does not have permissions to a resource or has exceeded usage limit""" + pass class OperationNotSupportedException(Exception): """Raised when sdk does not support requested operation""" + pass diff --git a/libs/labelbox/src/labelbox/orm/comparison.py b/libs/labelbox/src/labelbox/orm/comparison.py index 91c226652..7830549ea 100644 --- a/libs/labelbox/src/labelbox/orm/comparison.py +++ b/libs/labelbox/src/labelbox/orm/comparison.py @@ -1,4 +1,5 @@ from enum import Enum, auto + """ Classes for defining the client-side comparison operations used for filtering data in fetches. Intended for use by library internals and not by the end user. @@ -6,7 +7,7 @@ class LogicalExpressionComponent: - """ Implements bitwise logical operator methods (&, | and ~) so they + """Implements bitwise logical operator methods (&, | and ~) so they return a LogicalExpression object containing this LogicalExpressionComponent. """ @@ -26,22 +27,23 @@ def __invert__(self): class LogicalExpression(LogicalExpressionComponent): - """ A unary (NOT) or binary (AND, OR) logical expression between - Comparison or LogicalExpression objects. """ + """A unary (NOT) or binary (AND, OR) logical expression between + Comparison or LogicalExpression objects.""" class Op(Enum): - """ Type of logical operation. """ + """Type of logical operation.""" + AND = auto() OR = auto() NOT = auto() def __call__(self, first, second=None): - """ Forwards to LogicalExpression constructor, passing `self` - as the `op` argument. """ + """Forwards to LogicalExpression constructor, passing `self` + as the `op` argument.""" return LogicalExpression(self, first, second) def __init__(self, op, first, second=None): - """ LogicalExpression constructor. + """LogicalExpression constructor. Args: op (LogicalExpression.Op): The type of logical operation. @@ -54,12 +56,14 @@ def __init__(self, op, first, second=None): def __eq__(self, other): return self.op == other.op and ( - (self.first == other.first and self.second == other.second) or - (self.first == other.second and self.second == other.first)) + (self.first == other.first and self.second == other.second) + or (self.first == other.second and self.second == other.first) + ) def __hash__(self): - return hash( - self.op) + 2833 * hash(self.first) + 2837 * hash(self.second) + return ( + hash(self.op) + 2833 * hash(self.first) + 2837 * hash(self.second) + ) def __repr__(self): return "%r %s %r" % (self.first, self.op.name, self.second) @@ -69,11 +73,12 @@ def __str__(self): class Comparison(LogicalExpressionComponent): - """ A comparison between a database value (represented by a - `labelbox.schema.Field` object) and a constant value. """ + """A comparison between a database value (represented by a + `labelbox.schema.Field` object) and a constant value.""" class Op(Enum): - """ Type of the comparison operation. 
""" + """Type of the comparison operation.""" + EQ = auto() NE = auto() LT = auto() @@ -82,12 +87,12 @@ class Op(Enum): GE = auto() def __call__(self, *args): - """ Forwards to Comparison constructor, passing `self` - as the `op` argument. """ + """Forwards to Comparison constructor, passing `self` + as the `op` argument.""" return Comparison(self, *args) def __init__(self, op, field, value): - """ Comparison constructor. + """Comparison constructor. Args: op (Comparison.Op): The type of comparison. @@ -99,8 +104,11 @@ def __init__(self, op, field, value): self.value = value def __eq__(self, other): - return self.op == other.op and \ - self.field == other.field and self.value == other.value + return ( + self.op == other.op + and self.field == other.field + and self.value == other.value + ) def __hash__(self): return hash(self.op) + 2861 * hash(self.field) + 2927 * hash(self.value) diff --git a/libs/labelbox/src/labelbox/orm/db_object.py b/libs/labelbox/src/labelbox/orm/db_object.py index c4f87eac5..b210a8a5b 100644 --- a/libs/labelbox/src/labelbox/orm/db_object.py +++ b/libs/labelbox/src/labelbox/orm/db_object.py @@ -5,7 +5,11 @@ import json from labelbox import utils -from labelbox.exceptions import InvalidQueryError, InvalidAttributeError, OperationNotSupportedException +from labelbox.exceptions import ( + InvalidQueryError, + InvalidAttributeError, + OperationNotSupportedException, +) from labelbox.orm import query from labelbox.orm.model import Field, Relationship, Entity from labelbox.pagination import PaginatedCollection @@ -14,7 +18,7 @@ class DbObject(Entity): - """ A client-side representation of a database object (row). Intended as + """A client-side representation of a database object (row). Intended as base class for classes representing concrete database types (for example a Project). Exposes support functionalities so that the concrete subclass definition be as simple and DRY as possible. It should come down to just @@ -35,7 +39,7 @@ class DbObject(Entity): """ def __init__(self, client, field_values): - """ Constructor of a database object. Generally it should only be used + """Constructor of a database object. Generally it should only be used by library internals and not by the end user. Args: @@ -49,12 +53,16 @@ def __init__(self, client, field_values): value = field_values.get(utils.camel_case(relationship.name)) if relationship.cache and value is None: raise KeyError( - f"Expected field values for {relationship.name}") - setattr(self, relationship.name, - RelationshipManager(self, relationship, value)) + f"Expected field values for {relationship.name}" + ) + setattr( + self, + relationship.name, + RelationshipManager(self, relationship, value), + ) def _set_field_values(self, field_values): - """ Sets field values on this object. Ensures proper value conversions. + """Sets field values on this object. Ensures proper value conversions. Args: field_values (dict): Maps field names (GraphQL variant, snakeCase) to values. 
*Must* contain all field values for this object's @@ -69,7 +77,10 @@ def _set_field_values(self, field_values): except ValueError: logger.warning( "Failed to convert value '%s' to datetime for " - "field %s", value, field) + "field %s", + value, + field, + ) elif isinstance(field.field_type, Field.EnumType): value = field.field_type.enum_cls(value) elif isinstance(field.field_type, Field.ListType): @@ -80,7 +91,9 @@ def _set_field_values(self, field_values): except ValueError: logger.warning( "Failed to convert value '%s' to metadata for field %s", - value, field) + value, + field, + ) setattr(self, field.name, value) def __repr__(self): @@ -94,29 +107,34 @@ def __str__(self): attribute_values = { field.name: getattr(self, field.name) for field in self.fields() } - return "<%s %s>" % (self.type_name().split(".")[-1], - json.dumps(attribute_values, indent=4, default=str)) + return "<%s %s>" % ( + self.type_name().split(".")[-1], + json.dumps(attribute_values, indent=4, default=str), + ) def __eq__(self, other): - return (isinstance(other, DbObject) and - self.type_name() == other.type_name() and self.uid == other.uid) + return ( + isinstance(other, DbObject) + and self.type_name() == other.type_name() + and self.uid == other.uid + ) def __hash__(self): return 7541 * hash(self.type_name()) + hash(self.uid) class RelationshipManager: - """ Manages relationships (object fetching and updates) for a `DbObject` + """Manages relationships (object fetching and updates) for a `DbObject` instance. There is one RelationshipManager for each relationship in each `DbObject` instance. """ def __init__(self, source, relationship, value=None): """Args: - source (DbObject subclass instance): The object that's the source - of the relationship. - relationship (labelbox.schema.Relationship): The relationship - schema descriptor object. + source (DbObject subclass instance): The object that's the source + of the relationship. + relationship (labelbox.schema.Relationship): The relationship + schema descriptor object. """ self.source = source self.relationship = relationship @@ -127,8 +145,8 @@ def __init__(self, source, relationship, value=None): self.config = relationship.config def __call__(self, *args, **kwargs): - """ Forwards the call to either `_to_many` or `_to_one` methods, - depending on relationship type. """ + """Forwards the call to either `_to_many` or `_to_one` methods, + depending on relationship type.""" if self.relationship.deprecation_warning: logger.warning(self.relationship.deprecation_warning) @@ -139,7 +157,7 @@ def __call__(self, *args, **kwargs): return self._to_one(*args, **kwargs) def _to_many(self, where=None, order_by=None): - """ Returns an iterable over the destination relationship objects. + """Returns an iterable over the destination relationship objects. Args: where (None, Comparison or LogicalExpression): Filtering clause. order_by (None or (Field, Field.Order)): Ordering clause. 
@@ -149,27 +167,35 @@ def _to_many(self, where=None, order_by=None): rel = self.relationship if where is not None and not self.supports_filtering: raise InvalidQueryError( - "Relationship %s.%s doesn't support filtering" % - (self.source.type_name(), rel.name)) + "Relationship %s.%s doesn't support filtering" + % (self.source.type_name(), rel.name) + ) if order_by is not None and not self.supports_sorting: raise InvalidQueryError( - "Relationship %s.%s doesn't support sorting" % - (self.source.type_name(), rel.name)) + "Relationship %s.%s doesn't support sorting" + % (self.source.type_name(), rel.name) + ) if rel.filter_deleted: not_deleted = rel.destination_type.deleted == False where = not_deleted if where is None else where & not_deleted query_string, params = query.relationship( - self.source if self.filter_on_id else type(self.source), rel, where, - order_by) + self.source if self.filter_on_id else type(self.source), + rel, + where, + order_by, + ) return PaginatedCollection( - self.source.client, query_string, params, + self.source.client, + query_string, + params, [utils.camel_case(self.source.type_name()), rel.graphql_name], - rel.destination_type) + rel.destination_type, + ) def _to_one(self): - """ Returns the relationship destination object. """ + """Returns the relationship destination object.""" rel = self.relationship if self.value: @@ -178,7 +204,8 @@ def _to_one(self): query_string, params = query.relationship(self.source, rel, None, None) result = self.source.client.execute(query_string, params) result = result and result.get( - utils.camel_case(type(self.source).type_name())) + utils.camel_case(type(self.source).type_name()) + ) result = result and result.get(rel.graphql_name) if result is None: return None @@ -186,26 +213,28 @@ def _to_one(self): return rel.destination_type(self.source.client, result) def connect(self, other): - """ Connects source object of this manager to the `other` object. """ + """Connects source object of this manager to the `other` object.""" query_string, params = query.update_relationship( - self.source, other, self.relationship, "connect") + self.source, other, self.relationship, "connect" + ) self.source.client.execute(query_string, params) def disconnect(self, other): - """ Disconnects source object of this manager from the `other` object. """ + """Disconnects source object of this manager from the `other` object.""" if not self.config.disconnect_supported: raise OperationNotSupportedException( - "Disconnect is not supported for this relationship") + "Disconnect is not supported for this relationship" + ) query_string, params = query.update_relationship( - self.source, other, self.relationship, "disconnect") + self.source, other, self.relationship, "disconnect" + ) self.source.client.execute(query_string, params) class Updateable: - def update(self, **kwargs): - """ Updates this DB object with new values. Values should be + """Updates this DB object with new values. Values should be passed as key-value arguments with field names as keys: >>> db_object.update(name="New name", title="A title") @@ -229,10 +258,10 @@ def update(self, **kwargs): class Deletable: - """ Implements deletion for objects that have a `deleted` attribute. """ + """Implements deletion for objects that have a `deleted` attribute.""" def delete(self): - """ Deletes this DB object from the DB (server side). After + """Deletes this DB object from the DB (server side). After a call to this you should not use this DB object anymore. 
""" query_string, params = query.delete(self) @@ -240,7 +269,7 @@ def delete(self): class BulkDeletable: - """ Implements deletion for objects that have a custom, bulk deletion + """Implements deletion for objects that have a custom, bulk deletion mutation (accepts a list of IDs of objects to be deleted). A subclass must override the `bulk_delete` static method so it @@ -263,13 +292,14 @@ def _bulk_delete(objects, use_where_clause): types = {type(o) for o in objects} if len(types) != 1: raise InvalidQueryError( - "Can't bulk-delete objects of different types: %r" % types) + "Can't bulk-delete objects of different types: %r" % types + ) query_str, params = query.bulk_delete(objects, use_where_clause) objects[0].client.execute(query_str, params) def delete(self): - """ Deletes this DB object from the DB (server side). After + """Deletes this DB object from the DB (server side). After a call to this you should not use this DB object anymore. """ type(self).bulk_delete([self]) @@ -295,7 +325,8 @@ def wrapper(*args, **kwargs): else: raise ValueError( f"Static method {fn.__name__} must have a client passed in as the first " - f"argument or as a keyword argument.") + f"argument or as a keyword argument." + ) wrapped_fn = fn.__func__ else: client = args[0].client @@ -306,7 +337,8 @@ def wrapper(*args, **kwargs): f"This function {fn.__name__} relies on a experimental feature in the api. " f"This means that the interface could change. " f"Set `enable_experimental=True` in the client to enable use of " - f"experimental functions.") + f"experimental functions." + ) return wrapped_fn(*args, **kwargs) return wrapper diff --git a/libs/labelbox/src/labelbox/orm/model.py b/libs/labelbox/src/labelbox/orm/model.py index 5720b67cc..84dcac774 100644 --- a/libs/labelbox/src/labelbox/orm/model.py +++ b/libs/labelbox/src/labelbox/orm/model.py @@ -6,13 +6,14 @@ from labelbox import utils from labelbox.exceptions import InvalidAttributeError from labelbox.orm.comparison import Comparison + """ Defines Field, Relationship and Entity. These classes are building blocks for defining the Labelbox schema, DB object operations and queries. """ class Field: - """ Represents a field in a database table. A Field has a name, a type + """Represents a field in a database table. A Field has a name, a type (corresponds to server-side GraphQL type) and a server-side name. The server-side name is most often just a camelCase version of the client-side snake_case name. @@ -48,7 +49,6 @@ class Type(Enum): Json = auto() class EnumType: - def __init__(self, enum_cls: type): self.enum_cls = enum_cls @@ -57,7 +57,7 @@ def name(self): return self.enum_cls.__name__ class ListType: - """ Represents Field that is a list of some object. + """Represents Field that is a list of some object. Args: list_cls (type): Type of object that list is made of. graphql_type (str): Inner object's graphql type. @@ -76,7 +76,8 @@ def name(self): return f"[{self.graphql_type}]" class Order(Enum): - """ Type of sort ordering. """ + """Type of sort ordering.""" + Asc = auto() Desc = auto() @@ -116,12 +117,14 @@ def Json(*args): def List(list_cls: type, graphql_type=None, **kwargs): return Field(Field.ListType(list_cls, graphql_type), **kwargs) - def __init__(self, - field_type: Union[Type, EnumType, ListType], - name, - graphql_name=None, - result_subquery=None): - """ Field init. + def __init__( + self, + field_type: Union[Type, EnumType, ListType], + name, + graphql_name=None, + result_subquery=None, + ): + """Field init. 
Args: field_type (Field.Type): The type of the field. name (str): client-side Python attribute name of a database @@ -140,7 +143,7 @@ def __init__(self, @property def asc(self): - """ Property that resolves to tuple (Field, Field.Order). + """Property that resolves to tuple (Field, Field.Order). Used for easy definition of sort ordering: >>> projects_ordered = client.get_projects(order_by=Project.name.asc) """ @@ -148,14 +151,14 @@ def asc(self): @property def desc(self): - """ Property that resolves to tuple (Field, Field.Order). + """Property that resolves to tuple (Field, Field.Order). Used for easy definition of sort ordering: >>> projects_ordered = client.get_projects(order_by=Project.name.desc) """ return (self, Field.Order.Desc) def __eq__(self, other): - """ Equality of Fields has two meanings. If comparing to a Field object, + """Equality of Fields has two meanings. If comparing to a Field object, then a boolean indicator if the fields are identical is returned. If comparing to any other type, a Comparison object is created. """ @@ -165,7 +168,7 @@ def __eq__(self, other): return Comparison.Op.EQ(self, other) def __ne__(self, other): - """ Equality of Fields has two meanings. If comparing to a Field object, + """Equality of Fields has two meanings. If comparing to a Field object, then a boolean indicator if the fields are identical is returned. If comparing to any other type, a Comparison object is created. """ @@ -199,7 +202,7 @@ def __repr__(self): class Relationship: - """ Represents a relationship in a database table. + """Represents a relationship in a database table. Attributes: relationship_type (Relationship.Type): Indicator if to-one or to-many @@ -236,15 +239,17 @@ def ToOne(*args, **kwargs): def ToMany(*args, **kwargs): return Relationship(Relationship.Type.ToMany, *args, **kwargs) - def __init__(self, - relationship_type, - destination_type_name, - filter_deleted=True, - name=None, - graphql_name=None, - cache=False, - deprecation_warning=None, - config=Config()): + def __init__( + self, + relationship_type, + destination_type_name, + filter_deleted=True, + name=None, + graphql_name=None, + cache=False, + deprecation_warning=None, + config=Config(), + ): self.relationship_type = relationship_type self.destination_type_name = destination_type_name self.filter_deleted = filter_deleted @@ -254,7 +259,8 @@ def __init__(self, if name is None: name = utils.snake_case(destination_type_name) + ( - "s" if relationship_type == Relationship.Type.ToMany else "") + "s" if relationship_type == Relationship.Type.ToMany else "" + ) self.name = name if graphql_name is None: @@ -273,10 +279,11 @@ def __repr__(self): class EntityMeta(type): - """ Entity metaclass. Registers Entity subclasses as attributes + """Entity metaclass. Registers Entity subclasses as attributes of the Entity class object so they can be referenced for example like: Entity.Project. 
""" + # Maps Entity name to Relationships for all currently defined Entities relationship_mappings: Dict[str, List[Relationship]] = {} @@ -288,14 +295,16 @@ def __init__(cls, clsname, superclasses, attributedict): cls.validate_cached_relationships() if clsname != "Entity": setattr(Entity, clsname, cls) - EntityMeta.relationship_mappings[utils.snake_case( - cls.__name__)] = cls.relationships() + EntityMeta.relationship_mappings[utils.snake_case(cls.__name__)] = ( + cls.relationships() + ) @staticmethod def raise_for_nested_cache(first: str, middle: str, last: List[str]): raise TypeError( "Cannot cache a relationship to an Entity with its own cached relationship(s). " - f"`{first}` caches `{middle}` which caches `{last}`") + f"`{first}` caches `{middle}` which caches `{last}`" + ) @staticmethod def cached_entities(entity_name: str): @@ -329,8 +338,11 @@ def validate_cached_relationships(cls): for rel in cached_rels: nested = cls.cached_entities(rel.name) if nested: - cls.raise_for_nested_cache(utils.snake_case(cls.__name__), - rel.name, list(nested.keys())) + cls.raise_for_nested_cache( + utils.snake_case(cls.__name__), + rel.name, + list(nested.keys()), + ) # If the current Entity (cls) has any cached relationships (cached_rels) # then no other defined Entity (entities in EntityMeta.relationship_mappings) can cache this Entity. @@ -347,12 +359,13 @@ def validate_cached_relationships(cls): cls.raise_for_nested_cache( utils.snake_case(entity_name), utils.snake_case(cls.__name__), - [entity.name for entity in cached_rels]) + [entity.name for entity in cached_rels], + ) class Entity(metaclass=EntityMeta): - """ An entity that contains fields and relationships. Base class - for DbObject (which is base class for concrete schema classes). """ + """An entity that contains fields and relationships. Base class + for DbObject (which is base class for concrete schema classes).""" # Every Entity has an "id" and a "deleted" field # Name the "id" field "uid" in Python to avoid conflict with keyword. @@ -392,7 +405,7 @@ class Entity(metaclass=EntityMeta): @classmethod def _attributes_of_type(cls, attr_type): - """ Yields all the attributes in `cls` of the given `attr_type`. """ + """Yields all the attributes in `cls` of the given `attr_type`.""" for attr_name in dir(cls): attr = getattr(cls, attr_name) if isinstance(attr, attr_type): @@ -400,7 +413,7 @@ def _attributes_of_type(cls, attr_type): @classmethod def fields(cls): - """ Returns a generator that yields all the Fields declared in a + """Returns a generator that yields all the Fields declared in a concrete subclass. """ for attr in cls._attributes_of_type(Field): @@ -409,14 +422,14 @@ def fields(cls): @classmethod def relationships(cls): - """ Returns a generator that yields all the Relationships declared in + """Returns a generator that yields all the Relationships declared in a concrete subclass. """ return cls._attributes_of_type(Relationship) @classmethod def field(cls, field_name): - """ Returns a Field object for the given name. + """Returns a Field object for the given name. Args: field_name (str): Field name, Python (snake-case) convention. Return: @@ -432,7 +445,7 @@ def field(cls, field_name): @classmethod def attribute(cls, attribute_name): - """ Returns a Field or a Relationship object for the given name. + """Returns a Field or a Relationship object for the given name. Args: attribute_name (str): Field or Relationship name, Python (snake-case) convention. 
@@ -449,7 +462,7 @@ def attribute(cls, attribute_name): @classmethod def type_name(cls): - """ Returns this DB object type name in TitleCase. For example: - Project, DataRow, ... + """Returns this DB object type name in TitleCase. For example: + Project, DataRow, ... """ return cls.__name__.split(".")[-1] diff --git a/libs/labelbox/src/labelbox/orm/query.py b/libs/labelbox/src/labelbox/orm/query.py index f28714d09..8fa9fea00 100644 --- a/libs/labelbox/src/labelbox/orm/query.py +++ b/libs/labelbox/src/labelbox/orm/query.py @@ -2,14 +2,19 @@ from typing import Any, Dict from labelbox import utils -from labelbox.exceptions import InvalidQueryError, InvalidAttributeError, MalformedQueryException +from labelbox.exceptions import ( + InvalidQueryError, + InvalidAttributeError, + MalformedQueryException, +) from labelbox.orm.comparison import LogicalExpression, Comparison from labelbox.orm.model import Field, Relationship, Entity + """ Common query creation functionality. """ def format_param_declaration(params): - """ Formats the parameters dictionary into a declaration of GraphQL + """Formats the parameters dictionary into a declaration of GraphQL query parameters. Args: @@ -27,12 +32,18 @@ def attr_type(attr): else: return Field.Type.ID.name - return "(" + ", ".join("$%s: %s!" % (param, attr_type(attr)) - for param, (_, attr) in params.items()) + ")" + return ( + "(" + + ", ".join( + "$%s: %s!" % (param, attr_type(attr)) + for param, (_, attr) in params.items() + ) + + ")" + ) def results_query_part(entity): - """ Generates the results part of the query. The results contain + """Generates the results part of the query. The results contain all the entity's fields as well as prefetched relationships. Note that this is a recursive function. If there is a cycle in the @@ -44,30 +55,30 @@ def results_query_part(entity): # Query for fields fields = [ field.result_subquery - if field.result_subquery is not None else field.graphql_name + if field.result_subquery is not None + else field.graphql_name for field in entity.fields() ] # Query for cached relationships - fields.extend([ - Query(rel.graphql_name, rel.destination_type).format()[0] - for rel in entity.relationships() - if rel.cache - ]) + fields.extend( + [ + Query(rel.graphql_name, rel.destination_type).format()[0] + for rel in entity.relationships() + if rel.cache + ] + ) return " ".join(fields) class Query: - """ A data structure used during the construction of a query. Supports - subquery (also Query object) nesting for relationship. """ - - def __init__(self, - what, - subquery, - where=None, - paginate=False, - order_by=None): - """ Initializer. + """A data structure used during the construction of a query. Supports + subquery (also Query object) nesting for relationship.""" + + def __init__( + self, what, subquery, where=None, paginate=False, order_by=None + ): + """Initializer. Args: what (str): What is being queried. Typically an object type in singular or plural (i.e. "project" or "projects"). @@ -88,7 +99,7 @@ def __init__(self, self.order_by = order_by def format_subquery(self): - """ Formats the subquery (a Query or Entity subtype). """ + """Formats the subquery (a Query or Entity subtype).""" if isinstance(self.subquery, Query): return self.subquery.format() elif issubclass(self.subquery, Entity): @@ -97,14 +108,14 @@ def format_subquery(self): raise MalformedQueryException() def format_clauses(self, params): - """ Formats the where, order_by and pagination clauses. + """Formats the where, order_by and pagination clauses. 
Args: params (dict): The current parameter dictionary. """ def format_where(node): - """ Helper that resursively constructs a where clause from a - LogicalExpression tree (leaf nodes are Comparisons). """ + """Helper that resursively constructs a where clause from a + LogicalExpression tree (leaf nodes are Comparisons).""" COMPARISON_TO_SUFFIX = { Comparison.Op.EQ: "", Comparison.Op.NE: "_not", @@ -117,23 +128,29 @@ def format_where(node): if isinstance(node, Comparison): param_name = "param_%d" % len(params) params[param_name] = (node.value, node.field) - return "{%s%s: $%s}" % (node.field.graphql_name, - COMPARISON_TO_SUFFIX[node.op], - param_name) + return "{%s%s: $%s}" % ( + node.field.graphql_name, + COMPARISON_TO_SUFFIX[node.op], + param_name, + ) if node.op == LogicalExpression.Op.NOT: return "{NOT: [%s]}" % format_where(node.first) - return "{%s: [%s, %s]}" % (node.op.name.upper(), - format_where(node.first), - format_where(node.second)) + return "{%s: [%s, %s]}" % ( + node.op.name.upper(), + format_where(node.first), + format_where(node.second), + ) paginate = "skip: %d first: %d" if self.paginate else "" where = "where: %s" % format_where(self.where) if self.where else "" if self.order_by: - order_by = "orderBy: %s_%s" % (self.order_by[0].graphql_name, - self.order_by[1].name.upper()) + order_by = "orderBy: %s_%s" % ( + self.order_by[0].graphql_name, + self.order_by[1].name.upper(), + ) else: order_by = "" @@ -141,7 +158,7 @@ def format_where(node): return "(" + clauses + ")" if clauses else "" def format(self): - """ Formats the full query but without "query" prefix, query name + """Formats the full query but without "query" prefix, query name and parameter declaration. Return: (str, dict) tuple. str is the query and dict maps parameter @@ -153,7 +170,7 @@ def format(self): return query, params def format_top(self, name): - """ Formats the full query including "query" prefix, query name + """Formats the full query including "query" prefix, query name and parameter declaration. The result of this function can be sent to the Client object for execution. @@ -171,7 +188,7 @@ def format_top(self, name): def get_single(entity, uid): - """ Constructs the query and params dict for obtaining a single object. Either + """Constructs the query and params dict for obtaining a single object. Either on ID, or without params. Args: entity (type): An Entity subtype being obtained. @@ -181,12 +198,13 @@ def get_single(entity, uid): """ type_name = entity.type_name() where = entity.uid == uid if uid else None - return Query(utils.camel_case(type_name), entity, - where).format_top("Get" + type_name) + return Query(utils.camel_case(type_name), entity, where).format_top( + "Get" + type_name + ) def logical_ops(where): - """ Returns a generator that yields all the logical operator + """Returns a generator that yields all the logical operator type objects (`LogicalExpression.Op` instances) from a where clause. @@ -203,7 +221,7 @@ def logical_ops(where): def check_where_clause(entity, where): - """ Checks the `where` clause of a query. A `where` clause is legal + """Checks the `where` clause of a query. A `where` clause is legal if it only refers to fields found in the entity it's defined for. Since only AND logical operations are supported server-side at the moment, logical OR and NOT are illegal. @@ -217,7 +235,7 @@ def check_where_clause(entity, where): """ def fields(where): - """ Yields all the fields in a `where` clause. 
""" + """Yields all the fields in a `where` clause.""" if isinstance(where, LogicalExpression): for f in chain(fields(where.first), fields(where.second)): yield f @@ -233,15 +251,18 @@ def fields(where): if len(set(where_fields)) != len(where_fields): raise InvalidQueryError( "Where clause contains multiple comparisons for " - "the same field: %r." % where) + "the same field: %r." % where + ) if set(logical_ops(where)) not in (set(), {LogicalExpression.Op.AND}): - raise InvalidQueryError("Currently only AND logical ops are allowed in " - "the where clause of a query.") + raise InvalidQueryError( + "Currently only AND logical ops are allowed in " + "the where clause of a query." + ) def check_order_by_clause(entity, order_by): - """ Checks that the `order_by` clause field is a part of `entity`. + """Checks that the `order_by` clause field is a part of `entity`. Args: entity (type): An Entity subclass type. @@ -257,7 +278,7 @@ def check_order_by_clause(entity, order_by): def get_all(entity, where): - """ Constructs a query that fetches all items of the given type. The + """Constructs a query that fetches all items of the given type. The resulting query is intended to be used for pagination, it contains two python-string int-placeholders (%d) for 'skip' and 'first' pagination parameters. @@ -276,7 +297,7 @@ def get_all(entity, where): def relationship(source, relationship, where, order_by): - """ Constructs a query that fetches all items from a -to-many + """Constructs a query that fetches all items from a -to-many relationship. To be used like: >>> project = ... >>> query_str, params = relationship(Project, "datasets", Dataset) @@ -304,17 +325,24 @@ def relationship(source, relationship, where, order_by): check_where_clause(relationship.destination_type, where) check_order_by_clause(relationship.destination_type, order_by) to_many = relationship.relationship_type == Relationship.Type.ToMany - subquery = Query(relationship.graphql_name, relationship.destination_type, - where, to_many, order_by) - query_where = type(source).uid == source.uid if isinstance(source, Entity) \ - else None + subquery = Query( + relationship.graphql_name, + relationship.destination_type, + where, + to_many, + order_by, + ) + query_where = ( + type(source).uid == source.uid if isinstance(source, Entity) else None + ) query = Query(utils.camel_case(source.type_name()), subquery, query_where) - return query.format_top("Get" + source.type_name() + - utils.title_case(relationship.graphql_name)) + return query.format_top( + "Get" + source.type_name() + utils.title_case(relationship.graphql_name) + ) def create(entity, data): - """ Generates a query and parameters for creating a new DB object. + """Generates a query and parameters for creating a new DB object. 
Args: entity (type): An Entity subtype indicating which kind of @@ -330,8 +358,10 @@ def format_param_value(attribute, param): if isinstance(attribute, Field): return "%s: $%s" % (attribute.graphql_name, param) else: - return "%s: {connect: {id: $%s}}" % (utils.camel_case( - attribute.graphql_name), param) + return "%s: {connect: {id: $%s}}" % ( + utils.camel_case(attribute.graphql_name), + param, + ) # Convert data to params params = { @@ -339,16 +369,21 @@ def format_param_value(attribute, param): } query_str = """mutation Create%sPyApi%s{create%s(data: {%s}) {%s}} """ % ( - type_name, format_param_declaration(params), type_name, " ".join( + type_name, + format_param_declaration(params), + type_name, + " ".join( format_param_value(attribute, param) - for param, (_, attribute) in params.items()), - results_query_part(entity)) + for param, (_, attribute) in params.items() + ), + results_query_part(entity), + ) return query_str, {name: value for name, (value, _) in params.items()} def update_relationship(a, b, relationship, update): - """ Updates the relationship in DB object `a` to connect or disconnect + """Updates the relationship in DB object `a` to connect or disconnect DB object `b`. Args: @@ -360,8 +395,10 @@ def update_relationship(a, b, relationship, update): Return: (query_string, query_parameters) """ - to_one_disconnect = update == "disconnect" and \ - relationship.relationship_type == Relationship.Type.ToOne + to_one_disconnect = ( + update == "disconnect" + and relationship.relationship_type == Relationship.Type.ToOne + ) a_uid_param = utils.camel_case(type(a).type_name()) + "Id" @@ -375,9 +412,16 @@ def update_relationship(a, b, relationship, update): query_str = """mutation %s%sAnd%sPyApi%s{update%s( where: {id: $%s} data: {%s: {%s: %s}}) {id}} """ % ( - utils.title_case(update), type(a).type_name(), type(b).type_name(), - param_declr, utils.title_case(type(a).type_name()), a_uid_param, - relationship.graphql_name, update, b_query) + utils.title_case(update), + type(a).type_name(), + type(b).type_name(), + param_declr, + utils.title_case(type(a).type_name()), + a_uid_param, + relationship.graphql_name, + update, + b_query, + ) if to_one_disconnect: params = {a_uid_param: a.uid} @@ -388,7 +432,7 @@ def update_relationship(a, b, relationship, update): def update_fields(db_object, values): - """ Creates a query that updates `db_object` fields with the + """Creates a query that updates `db_object` fields with the given values. Args: @@ -400,8 +444,10 @@ def update_fields(db_object, values): """ type_name = db_object.type_name() id_param = "%sId" % type_name - values_str = " ".join("%s: $%s" % (field.graphql_name, field.graphql_name) - for field, _ in values.items()) + values_str = " ".join( + "%s: $%s" % (field.graphql_name, field.graphql_name) + for field, _ in values.items() + ) params = { field.graphql_name: (value, field) for field, value in values.items() } @@ -409,14 +455,19 @@ def update_fields(db_object, values): query_str = """mutation update%sPyApi%s{update%s( where: {id: $%s} data: {%s}) {%s}} """ % ( - utils.title_case(type_name), format_param_declaration(params), - type_name, id_param, values_str, results_query_part(type(db_object))) + utils.title_case(type_name), + format_param_declaration(params), + type_name, + id_param, + values_str, + results_query_part(type(db_object)), + ) return query_str, {name: value for name, (value, _) in params.items()} def delete(db_object): - """ Generates a query that deletes the given `db_object` from the DB. 
+ """Generates a query that deletes the given `db_object` from the DB. Args: db_object (DbObject): The DB object being deleted. @@ -424,14 +475,17 @@ def delete(db_object): id_param = "%sId" % db_object.type_name() query_str = """mutation delete%sPyApi%s{update%s( where: {id: $%s} data: {deleted: true}) {id}} """ % ( - db_object.type_name(), "($%s: ID!)" % id_param, db_object.type_name(), - id_param) + db_object.type_name(), + "($%s: ID!)" % id_param, + db_object.type_name(), + id_param, + ) return query_str, {id_param: db_object.uid} def bulk_delete(db_objects, use_where_clause): - """ Generates a query that bulk-deletes the given `db_objects` from the + """Generates a query that bulk-deletes the given `db_objects` from the DB. Args: @@ -441,13 +495,17 @@ def bulk_delete(db_objects, use_where_clause): """ type_name = db_objects[0].type_name() if use_where_clause: - query_str = "mutation delete%ssPyApi{delete%ss(where: {%sIds: [%s]}){id}}" + query_str = ( + "mutation delete%ssPyApi{delete%ss(where: {%sIds: [%s]}){id}}" + ) else: query_str = "mutation delete%ssPyApi{delete%ss(%sIds: [%s]){id}}" query_str = query_str % ( - utils.title_case(type_name), utils.title_case(type_name), - utils.camel_case(type_name), ", ".join( - '"%s"' % db_object.uid for db_object in db_objects)) + utils.title_case(type_name), + utils.title_case(type_name), + utils.camel_case(type_name), + ", ".join('"%s"' % db_object.uid for db_object in db_objects), + ) return query_str, {} diff --git a/libs/labelbox/src/labelbox/pagination.py b/libs/labelbox/src/labelbox/pagination.py index a173505c9..a3b170ec7 100644 --- a/libs/labelbox/src/labelbox/pagination.py +++ b/libs/labelbox/src/labelbox/pagination.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from typing import TYPE_CHECKING + if TYPE_CHECKING: from labelbox import Client from labelbox.orm.db_object import DbObject @@ -11,7 +12,7 @@ class PaginatedCollection: - """ An iterable collection of database objects (Projects, Labels, etc...). + """An iterable collection of database objects (Projects, Labels, etc...). Implements automatic (transparent to the user) paginated fetching during iteration. Intended for use by library internals and not by the end user. @@ -19,15 +20,17 @@ class PaginatedCollection: __init__ map exactly to object attributes. """ - def __init__(self, - client: "Client", - query: str, - params: Dict[str, Union[str, int]], - dereferencing: Union[List[str], Dict[str, Any]], - obj_class: Union[Type["DbObject"], Callable[[Any, Any], Any]], - cursor_path: Optional[List[str]] = None, - experimental: bool = False): - """ Creates a PaginatedCollection. + def __init__( + self, + client: "Client", + query: str, + params: Dict[str, Union[str, int]], + dereferencing: Union[List[str], Dict[str, Any]], + obj_class: Union[Type["DbObject"], Callable[[Any, Any], Any]], + cursor_path: Optional[List[str]] = None, + experimental: bool = False, + ): + """Creates a PaginatedCollection. Args: client (labelbox.Client): the client used for fetching data from DB. 
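# A hedged usage sketch for PaginatedCollection: client.get_projects() is one
# client call known to return such a collection; the API key is a placeholder.
import labelbox as lb

client = lb.Client(api_key="YOUR_API_KEY")  # placeholder credential

# Iterating fetches pages lazily and transparently, one page per request.
for project in client.get_projects():
    print(project.uid, project.name)

# get_many(n) drains at most n items without iterating the whole collection.
first_five = client.get_projects().get_many(5)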
@@ -48,18 +51,19 @@ def __init__(self, self._data_ind = 0 pagination_kwargs = { - 'client': client, - 'obj_class': obj_class, - 'dereferencing': dereferencing, - 'experimental': experimental, - 'query': query, - 'params': params + "client": client, + "obj_class": obj_class, + "dereferencing": dereferencing, + "experimental": experimental, + "query": query, + "params": params, } - self.paginator = _CursorPagination( - cursor_path, ** - pagination_kwargs) if cursor_path else _OffsetPagination( - **pagination_kwargs) + self.paginator = ( + _CursorPagination(cursor_path, **pagination_kwargs) + if cursor_path + else _OffsetPagination(**pagination_kwargs) + ) def __iter__(self): self._data_ind = 0 @@ -107,10 +111,15 @@ def get_many(self, n: int): class _Pagination(ABC): - - def __init__(self, client: "Client", obj_class: Type["DbObject"], - dereferencing: Dict[str, Any], query: str, - params: Dict[str, Any], experimental: bool): + def __init__( + self, + client: "Client", + obj_class: Type["DbObject"], + dereferencing: Dict[str, Any], + query: str, + params: Dict[str, Any], + experimental: bool, + ): self.client = client self.obj_class = obj_class self.dereferencing = dereferencing @@ -125,16 +134,14 @@ def get_page_data(self, results: Dict[str, Any]) -> List["DbObject"]: return [self.obj_class(self.client, result) for result in results] @abstractmethod - def get_next_page(self) -> Tuple[Dict[str, Any], bool]: - ... + def get_next_page(self) -> Tuple[Dict[str, Any], bool]: ... class _CursorPagination(_Pagination): - def __init__(self, cursor_path: List[str], *args, **kwargs): super().__init__(*args, **kwargs) self.cursor_path = cursor_path - self.next_cursor: Optional[Any] = kwargs.get('params', {}).get('from') + self.next_cursor: Optional[Any] = kwargs.get("params", {}).get("from") def increment_page(self, results: Dict[str, Any]): for path in self.cursor_path: @@ -145,11 +152,11 @@ def fetched_all(self) -> bool: return not self.next_cursor def fetch_results(self) -> Dict[str, Any]: - page_size = self.params.get('first', _PAGE_SIZE) - self.params.update({'from': self.next_cursor, 'first': page_size}) - return self.client.execute(self.query, - self.params, - experimental=self.experimental) + page_size = self.params.get("first", _PAGE_SIZE) + self.params.update({"from": self.next_cursor, "first": page_size}) + return self.client.execute( + self.query, self.params, experimental=self.experimental + ) def get_next_page(self): results = self.fetch_results() @@ -160,7 +167,6 @@ def get_next_page(self): class _OffsetPagination(_Pagination): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._fetched_pages = 0 @@ -173,9 +179,9 @@ def fetched_all(self, n_items: int) -> bool: def fetch_results(self) -> Dict[str, Any]: query = self.query % (self._fetched_pages * _PAGE_SIZE, _PAGE_SIZE) - return self.client.execute(query, - self.params, - experimental=self.experimental) + return self.client.execute( + query, self.params, experimental=self.experimental + ) def get_next_page(self): results = self.fetch_results() diff --git a/libs/labelbox/src/labelbox/parser.py b/libs/labelbox/src/labelbox/parser.py index fab41bb81..8f64adaf4 100644 --- a/libs/labelbox/src/labelbox/parser.py +++ b/libs/labelbox/src/labelbox/parser.py @@ -2,24 +2,23 @@ class NdjsonDecoder(json.JSONDecoder): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def decode(self, s: str, *args, **kwargs): - lines = ','.join(s.splitlines()) + lines = ",".join(s.splitlines()) text = f"[{lines}]" # NOTE: 
this is a hack to make json.loads work for ndjson return super().decode(text, *args, **kwargs) def loads(ndjson_string, **kwargs) -> list: - kwargs.setdefault('cls', NdjsonDecoder) + kwargs.setdefault("cls", NdjsonDecoder) return json.loads(ndjson_string, **kwargs) def dumps(obj, **kwargs): lines = map(lambda obj: json.dumps(obj, **kwargs), obj) - return '\n'.join(lines) + return "\n".join(lines) def dump(obj, io, **kwargs): diff --git a/libs/labelbox/src/labelbox/schema/__init__.py b/libs/labelbox/src/labelbox/schema/__init__.py index 9f187bf87..03327e0d1 100644 --- a/libs/labelbox/src/labelbox/schema/__init__.py +++ b/libs/labelbox/src/labelbox/schema/__init__.py @@ -26,4 +26,4 @@ import labelbox.schema.identifiable import labelbox.schema.catalog import labelbox.schema.ontology_kind -import labelbox.schema.project_overview \ No newline at end of file +import labelbox.schema.project_overview diff --git a/libs/labelbox/src/labelbox/schema/annotation_import.py b/libs/labelbox/src/labelbox/schema/annotation_import.py index 2d1fd8582..df7f272a3 100644 --- a/libs/labelbox/src/labelbox/schema/annotation_import.py +++ b/libs/labelbox/src/labelbox/schema/annotation_import.py @@ -3,7 +3,16 @@ import logging import os import time -from typing import Any, BinaryIO, Dict, List, Optional, Union, TYPE_CHECKING, cast +from typing import ( + Any, + BinaryIO, + Dict, + List, + Optional, + Union, + TYPE_CHECKING, + cast, +) from collections import defaultdict from google.api_core import retry @@ -16,7 +25,9 @@ from labelbox.orm.db_object import DbObject from labelbox.orm.model import Field, Relationship from labelbox.utils import is_exactly_one_set -from labelbox.schema.confidence_presence_checker import LabelsConfidencePresenceChecker +from labelbox.schema.confidence_presence_checker import ( + LabelsConfidencePresenceChecker, +) from labelbox.schema.enums import AnnotationImportState from labelbox.schema.serialization import serialize_labels @@ -92,14 +103,14 @@ def statuses(self) -> List[Dict[str, Any]]: self.wait_until_done() return self._fetch_remote_ndjson(self.status_file_url) - def wait_till_done(self, - sleep_time_seconds: int = 10, - show_progress: bool = False) -> None: + def wait_till_done( + self, sleep_time_seconds: int = 10, show_progress: bool = False + ) -> None: self.wait_until_done(sleep_time_seconds, show_progress) - def wait_until_done(self, - sleep_time_seconds: int = 10, - show_progress: bool = False) -> None: + def wait_until_done( + self, sleep_time_seconds: int = 10, show_progress: bool = False + ) -> None: """Blocks import job until certain conditions are met. 
Blocks until the AnnotationImport.state changes either to `AnnotationImportState.FINISHED` or `AnnotationImportState.FAILED`, @@ -108,9 +119,14 @@ def wait_until_done(self, sleep_time_seconds (int): a time to block between subsequent API calls show_progress (bool): should show progress bar """ - pbar = tqdm(total=100, - bar_format="{n}% |{bar}| [{elapsed}, {rate_fmt}{postfix}]" - ) if show_progress else None + pbar = ( + tqdm( + total=100, + bar_format="{n}% |{bar}| [{elapsed}, {rate_fmt}{postfix}]", + ) + if show_progress + else None + ) while self.state.value == AnnotationImportState.RUNNING.value: logger.info(f"Sleeping for {sleep_time_seconds} seconds...") time.sleep(sleep_time_seconds) @@ -122,9 +138,13 @@ def wait_until_done(self, pbar.update(100 - pbar.n) pbar.close() - @retry.Retry(predicate=retry.if_exception_type( - labelbox.exceptions.ApiLimitError, labelbox.exceptions.TimeoutError, - labelbox.exceptions.NetworkError)) + @retry.Retry( + predicate=retry.if_exception_type( + labelbox.exceptions.ApiLimitError, + labelbox.exceptions.TimeoutError, + labelbox.exceptions.NetworkError, + ) + ) def __backoff_refresh(self) -> None: self.refresh() @@ -145,21 +165,24 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]: return parser.loads(response.text) @classmethod - def _create_from_bytes(cls, client, variables, query_str, file_name, - bytes_data) -> Dict[str, Any]: + def _create_from_bytes( + cls, client, variables, query_str, file_name, bytes_data + ) -> Dict[str, Any]: operations = json.dumps({"variables": variables, "query": query_str}) data = { "operations": operations, - "map": (None, json.dumps({file_name: ["variables.file"]})) + "map": (None, json.dumps({file_name: ["variables.file"]})), } file_data = (file_name, bytes_data, NDJSON_MIME_TYPE) files = {file_name: file_data} return client.execute(data=data, files=files) @classmethod - def _get_ndjson_from_objects(cls, objects: Union[List[Dict[str, Any]], - List["Label"]], - object_name: str) -> BinaryIO: + def _get_ndjson_from_objects( + cls, + objects: Union[List[Dict[str, Any]], List["Label"]], + object_name: str, + ) -> BinaryIO: if not isinstance(objects, list): raise TypeError( f"{object_name} must be in a form of list. Found {type(objects)}" @@ -173,17 +196,15 @@ def _get_ndjson_from_objects(cls, objects: Union[List[Dict[str, Any]], raise ValueError(f"{object_name} cannot be empty") return data_str.encode( - 'utf-8' + "utf-8" ) # NOTICE this method returns bytes, NOT BinaryIO... should have done io.BytesIO(...) but not going to change this at the moment since it works and fools mypy def refresh(self) -> None: - """Synchronizes values of all fields with the database. - """ + """Synchronizes values of all fields with the database.""" cls = type(self) - res = cls.from_name(self.client, - self.parent_id, - self.name, - as_json=True) + res = cls.from_name( + self.client, self.parent_id, self.name, as_json=True + ) self._set_field_values(res) @classmethod @@ -193,26 +214,29 @@ def _validate_data_rows(cls, objects: List[Dict[str, Any]]): and only one of 'id' or 'globalKey' is provided. 
Shows up to `max_num_errors` errors if invalidated, to prevent - large number of error messages from being printed out + large number of error messages from being printed out """ errors = [] max_num_errors = 100 labels_per_datarow: Dict[str, Dict[str, int]] = defaultdict( - lambda: defaultdict(int)) + lambda: defaultdict(int) + ) for object in objects: - if 'dataRow' not in object: + if "dataRow" not in object: errors.append(f"'dataRow' is missing in {object}") continue - data_row_object = object['dataRow'] - if not is_exactly_one_set(data_row_object.get('id'), - data_row_object.get('globalKey')): + data_row_object = object["dataRow"] + if not is_exactly_one_set( + data_row_object.get("id"), data_row_object.get("globalKey") + ): errors.append( f"Must provide only one of 'id' or 'globalKey' for 'dataRow' in {object}" ) else: data_row_id = data_row_object.get( - 'globalKey') or data_row_object.get('id') - name = object.get('name') + "globalKey" + ) or data_row_object.get("id") + name = object.get("name") if name: labels_per_datarow[data_row_id][name] += 1 for data_row_id, label_annotations in labels_per_datarow.items(): @@ -224,7 +248,7 @@ def _validate_data_rows(cls, objects: List[Dict[str, Any]]): ) if errors: errors_length = len(errors) - formatted_errors = '\n'.join(errors[:max_num_errors]) + formatted_errors = "\n".join(errors[:max_num_errors]) if errors_length > max_num_errors: logger.warning( f"Found more than {max_num_errors} errors. Showing first {max_num_errors} error messages..." @@ -234,11 +258,13 @@ def _validate_data_rows(cls, objects: List[Dict[str, Any]]): ) @classmethod - def from_name(cls, - client: "labelbox.Client", - parent_id: str, - name: str, - as_json: bool = False): + def from_name( + cls, + client: "labelbox.Client", + parent_id: str, + name: str, + as_json: bool = False, + ): raise NotImplementedError("Inheriting class must override") @property @@ -247,7 +273,6 @@ def parent_id(self) -> str: class CreatableAnnotationImport(AnnotationImport): - @classmethod def create( cls, @@ -256,9 +281,9 @@ def create( name: str, path: Optional[str] = None, url: Optional[str] = None, - labels: Union[List[Dict[str, Any]], List["Label"]] = [] + labels: Union[List[Dict[str, Any]], List["Label"]] = [], ) -> "AnnotationImport": - if (not is_exactly_one_set(url, labels, path)): + if not is_exactly_one_set(url, labels, path): raise ValueError( "Must pass in a nonempty argument for one and only one of the following arguments: url, path, predictions" ) @@ -269,20 +294,25 @@ def create( return cls.create_from_objects(client, id, name, labels) @classmethod - def create_from_url(cls, client: "labelbox.Client", id: str, name: str, - url: str) -> "AnnotationImport": + def create_from_url( + cls, client: "labelbox.Client", id: str, name: str, url: str + ) -> "AnnotationImport": raise NotImplementedError("Inheriting class must override") @classmethod - def create_from_file(cls, client: "labelbox.Client", id: str, name: str, - path: str) -> "AnnotationImport": + def create_from_file( + cls, client: "labelbox.Client", id: str, name: str, path: str + ) -> "AnnotationImport": raise NotImplementedError("Inheriting class must override") @classmethod def create_from_objects( - cls, client: "labelbox.Client", id: str, name: str, - labels: Union[List[Dict[str, Any]], - List["Label"]]) -> "AnnotationImport": + cls, + client: "labelbox.Client", + id: str, + name: str, + labels: Union[List[Dict[str, Any]], List["Label"]], + ) -> "AnnotationImport": raise NotImplementedError("Inheriting class must 
override") @@ -297,8 +327,9 @@ def parent_id(self) -> str: return self.model_run_id @classmethod - def create_from_file(cls, client: "labelbox.Client", model_run_id: str, - name: str, path: str) -> "MEAPredictionImport": + def create_from_file( + cls, client: "labelbox.Client", model_run_id: str, name: str, path: str + ) -> "MEAPredictionImport": """ Create an MEA prediction import job from a file of annotations @@ -311,17 +342,20 @@ def create_from_file(cls, client: "labelbox.Client", model_run_id: str, MEAPredictionImport """ if os.path.exists(path): - with open(path, 'rb') as f: + with open(path, "rb") as f: return cls._create_mea_import_from_bytes( - client, model_run_id, name, f, - os.stat(path).st_size) + client, model_run_id, name, f, os.stat(path).st_size + ) else: raise ValueError(f"File {path} is not accessible") @classmethod def create_from_objects( - cls, client: "labelbox.Client", model_run_id: str, name, - predictions: Union[List[Dict[str, Any]], List["Label"]] + cls, + client: "labelbox.Client", + model_run_id: str, + name, + predictions: Union[List[Dict[str, Any]], List["Label"]], ) -> "MEAPredictionImport": """ Create an MEA prediction import job from an in memory dictionary @@ -334,14 +368,16 @@ def create_from_objects( Returns: MEAPredictionImport """ - data = cls._get_ndjson_from_objects(predictions, 'annotations') + data = cls._get_ndjson_from_objects(predictions, "annotations") - return cls._create_mea_import_from_bytes(client, model_run_id, name, - data, len(str(data))) + return cls._create_mea_import_from_bytes( + client, model_run_id, name, data, len(str(data)) + ) @classmethod - def create_from_url(cls, client: "labelbox.Client", model_run_id: str, - name: str, url: str) -> "MEAPredictionImport": + def create_from_url( + cls, client: "labelbox.Client", model_run_id: str, name: str, url: str + ) -> "MEAPredictionImport": """ Create an MEA prediction import job from a url The url must point to a file containing prediction annotations. @@ -358,21 +394,26 @@ def create_from_url(cls, client: "labelbox.Client", model_run_id: str, query_str = cls._get_url_mutation() return cls( client, - client.execute(query_str, - params={ - "fileUrl": url, - "modelRunId": model_run_id, - 'name': name - })["createModelErrorAnalysisPredictionImport"]) + client.execute( + query_str, + params={ + "fileUrl": url, + "modelRunId": model_run_id, + "name": name, + }, + )["createModelErrorAnalysisPredictionImport"], + ) else: raise ValueError(f"Url {url} is not reachable") @classmethod - def from_name(cls, - client: "labelbox.Client", - model_run_id: str, - name: str, - as_json: bool = False) -> "MEAPredictionImport": + def from_name( + cls, + client: "labelbox.Client", + model_run_id: str, + name: str, + as_json: bool = False, + ) -> "MEAPredictionImport": """ Retrieves an MEA import job. 
@@ -395,7 +436,8 @@ def from_name(cls, response = client.execute(query_str, params) if response is None: raise labelbox.exceptions.ResourceNotFoundError( - MEAPredictionImport, params) + MEAPredictionImport, params + ) response = response["modelErrorAnalysisPredictionImport"] if as_json: return response @@ -421,14 +463,19 @@ def _get_file_mutation(cls) -> str: @classmethod def _create_mea_import_from_bytes( - cls, client: "labelbox.Client", model_run_id: str, name: str, - bytes_data: BinaryIO, content_len: int) -> "MEAPredictionImport": + cls, + client: "labelbox.Client", + model_run_id: str, + name: str, + bytes_data: BinaryIO, + content_len: int, + ) -> "MEAPredictionImport": file_name = f"{model_run_id}__{name}.ndjson" variables = { "file": None, "contentLength": content_len, "modelRunId": model_run_id, - "name": name + "name": name, } query_str = cls._get_file_mutation() res = cls._create_from_bytes( @@ -452,10 +499,14 @@ def parent_id(self) -> str: return self.project().uid @classmethod - def create_for_model_run_data_rows(cls, client: "labelbox.Client", - model_run_id: str, - data_row_ids: List[str], project_id: str, - name: str) -> "MEAToMALPredictionImport": + def create_for_model_run_data_rows( + cls, + client: "labelbox.Client", + model_run_id: str, + data_row_ids: List[str], + project_id: str, + name: str, + ) -> "MEAToMALPredictionImport": """ Create an MEA to MAL prediction import job from a list of data row ids of a specific model run @@ -469,20 +520,25 @@ def create_for_model_run_data_rows(cls, client: "labelbox.Client", query_str = cls._get_model_run_data_rows_mutation() return cls( client, - client.execute(query_str, - params={ - "dataRowIds": data_row_ids, - "modelRunId": model_run_id, - "projectId": project_id, - "name": name - })["createMalPredictionImportForModelRunDataRows"]) + client.execute( + query_str, + params={ + "dataRowIds": data_row_ids, + "modelRunId": model_run_id, + "projectId": project_id, + "name": name, + }, + )["createMalPredictionImportForModelRunDataRows"], + ) @classmethod - def from_name(cls, - client: "labelbox.Client", - project_id: str, - name: str, - as_json: bool = False) -> "MEAToMALPredictionImport": + def from_name( + cls, + client: "labelbox.Client", + project_id: str, + name: str, + as_json: bool = False, + ) -> "MEAToMALPredictionImport": """ Retrieves an MEA to MAL import job. 
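# A hypothetical sketch of promoting a model run's predictions into a project
# as pre-labels via the MEA-to-MAL import above; all ids are placeholders.
from labelbox.schema.annotation_import import MEAToMALPredictionImport
import labelbox as lb

client = lb.Client(api_key="YOUR_API_KEY")  # placeholder credential

mea_to_mal = MEAToMALPredictionImport.create_for_model_run_data_rows(
    client=client,
    model_run_id="MODEL_RUN_ID",                      # placeholder
    data_row_ids=["DATA_ROW_ID_1", "DATA_ROW_ID_2"],  # placeholders
    project_id="PROJECT_ID",                          # placeholder
    name="example-mea-to-mal-import",
)
mea_to_mal.wait_until_done()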
@@ -505,7 +561,8 @@ def from_name(cls, response = client.execute(query_str, params) if response is None: raise labelbox.exceptions.ResourceNotFoundError( - MALPredictionImport, params) + MALPredictionImport, params + ) response = response["meaToMalPredictionImport"] if as_json: return response @@ -534,8 +591,9 @@ def parent_id(self) -> str: return self.project().uid @classmethod - def create_from_file(cls, client: "labelbox.Client", project_id: str, - name: str, path: str) -> "MALPredictionImport": + def create_from_file( + cls, client: "labelbox.Client", project_id: str, name: str, path: str + ) -> "MALPredictionImport": """ Create an MAL prediction import job from a file of annotations @@ -548,17 +606,20 @@ def create_from_file(cls, client: "labelbox.Client", project_id: str, MALPredictionImport """ if os.path.exists(path): - with open(path, 'rb') as f: + with open(path, "rb") as f: return cls._create_mal_import_from_bytes( - client, project_id, name, f, - os.stat(path).st_size) + client, project_id, name, f, os.stat(path).st_size + ) else: raise ValueError(f"File {path} is not accessible") @classmethod def create_from_objects( - cls, client: "labelbox.Client", project_id: str, name: str, - predictions: Union[List[Dict[str, Any]], List["Label"]] + cls, + client: "labelbox.Client", + project_id: str, + name: str, + predictions: Union[List[Dict[str, Any]], List["Label"]], ) -> "MALPredictionImport": """ Create an MAL prediction import job from an in memory dictionary @@ -572,22 +633,25 @@ def create_from_objects( MALPredictionImport """ - data = cls._get_ndjson_from_objects(predictions, 'annotations') + data = cls._get_ndjson_from_objects(predictions, "annotations") if len(predictions) > 0 and isinstance(predictions[0], Dict): predictions_dicts = cast(List[Dict[str, Any]], predictions) has_confidence = LabelsConfidencePresenceChecker.check( - predictions_dicts) + predictions_dicts + ) if has_confidence: logger.warning(""" Confidence scores are not supported in MAL Prediction Import. Corresponding confidence score values will be ignored. """) - return cls._create_mal_import_from_bytes(client, project_id, name, data, - len(str(data))) + return cls._create_mal_import_from_bytes( + client, project_id, name, data, len(str(data)) + ) @classmethod - def create_from_url(cls, client: "labelbox.Client", project_id: str, - name: str, url: str) -> "MALPredictionImport": + def create_from_url( + cls, client: "labelbox.Client", project_id: str, name: str, url: str + ) -> "MALPredictionImport": """ Create an MAL prediction import job from a url The url must point to a file containing prediction annotations. @@ -609,17 +673,21 @@ def create_from_url(cls, client: "labelbox.Client", project_id: str, params={ "fileUrl": url, "projectId": project_id, - 'name': name - })["createModelAssistedLabelingPredictionImport"]) + "name": name, + }, + )["createModelAssistedLabelingPredictionImport"], + ) else: raise ValueError(f"Url {url} is not reachable") @classmethod - def from_name(cls, - client: "labelbox.Client", - project_id: str, - name: str, - as_json: bool = False) -> "MALPredictionImport": + def from_name( + cls, + client: "labelbox.Client", + project_id: str, + name: str, + as_json: bool = False, + ) -> "MALPredictionImport": """ Retrieves an MAL import job. 
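# A hypothetical MAL (model-assisted labeling) import sketch using the
# create_from_objects path above; as the warning notes, any confidence scores
# in the payload are ignored for MAL imports. Ids and payload are placeholders.
from labelbox.schema.annotation_import import MALPredictionImport
import labelbox as lb

client = lb.Client(api_key="YOUR_API_KEY")  # placeholder credential
predictions = [...]  # placeholder: NDJSON-style dicts or Label objects

mal_import = MALPredictionImport.create_from_objects(
    client=client,
    project_id="PROJECT_ID",  # placeholder
    name="example-mal-import",
    predictions=predictions,
)
mal_import.wait_until_done(sleep_time_seconds=10)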
@@ -642,7 +710,8 @@ def from_name(cls, response = client.execute(query_str, params) if response is None: raise labelbox.exceptions.ResourceNotFoundError( - MALPredictionImport, params) + MALPredictionImport, params + ) response = response["modelAssistedLabelingPredictionImport"] if as_json: return response @@ -668,18 +737,24 @@ def _get_file_mutation(cls) -> str: @classmethod def _create_mal_import_from_bytes( - cls, client: "labelbox.Client", project_id: str, name: str, - bytes_data: BinaryIO, content_len: int) -> "MALPredictionImport": + cls, + client: "labelbox.Client", + project_id: str, + name: str, + bytes_data: BinaryIO, + content_len: int, + ) -> "MALPredictionImport": file_name = f"{project_id}__{name}.ndjson" variables = { "file": None, "contentLength": content_len, "projectId": project_id, - "name": name + "name": name, } query_str = cls._get_file_mutation() - res = cls._create_from_bytes(client, variables, query_str, file_name, - bytes_data) + res = cls._create_from_bytes( + client, variables, query_str, file_name, bytes_data + ) return cls(client, res["createModelAssistedLabelingPredictionImport"]) @@ -694,8 +769,9 @@ def parent_id(self) -> str: return self.project().uid @classmethod - def create_from_file(cls, client: "labelbox.Client", project_id: str, - name: str, path: str) -> "LabelImport": + def create_from_file( + cls, client: "labelbox.Client", project_id: str, name: str, path: str + ) -> "LabelImport": """ Create a label import job from a file of annotations @@ -708,18 +784,21 @@ def create_from_file(cls, client: "labelbox.Client", project_id: str, LabelImport """ if os.path.exists(path): - with open(path, 'rb') as f: + with open(path, "rb") as f: return cls._create_label_import_from_bytes( - client, project_id, name, f, - os.stat(path).st_size) + client, project_id, name, f, os.stat(path).st_size + ) else: raise ValueError(f"File {path} is not accessible") @classmethod def create_from_objects( - cls, client: "labelbox.Client", project_id: str, name: str, - labels: Union[List[Dict[str, Any]], - List["Label"]]) -> "LabelImport": + cls, + client: "labelbox.Client", + project_id: str, + name: str, + labels: Union[List[Dict[str, Any]], List["Label"]], + ) -> "LabelImport": """ Create a label import job from an in memory dictionary @@ -731,7 +810,7 @@ def create_from_objects( Returns: LabelImport """ - data = cls._get_ndjson_from_objects(labels, 'labels') + data = cls._get_ndjson_from_objects(labels, "labels") if len(labels) > 0 and isinstance(labels[0], Dict): label_dicts = cast(List[Dict[str, Any]], labels) @@ -741,12 +820,14 @@ def create_from_objects( Confidence scores are not supported in Label Import. Corresponding confidence score values will be ignored. """) - return cls._create_label_import_from_bytes(client, project_id, name, - data, len(str(data))) + return cls._create_label_import_from_bytes( + client, project_id, name, data, len(str(data)) + ) @classmethod - def create_from_url(cls, client: "labelbox.Client", project_id: str, - name: str, url: str) -> "LabelImport": + def create_from_url( + cls, client: "labelbox.Client", project_id: str, name: str, url: str + ) -> "LabelImport": """ Create a label annotation import job from a url The url must point to a file containing label annotations. 
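# A hypothetical ground-truth label import mirroring LabelImport.create_from_objects
# above; the project id is a placeholder and "labels" stands in for a list of
# NDJSON dicts or Label objects.
from labelbox.schema.annotation_import import LabelImport
import labelbox as lb

client = lb.Client(api_key="YOUR_API_KEY")  # placeholder credential
labels = [...]  # placeholder payload

label_import = LabelImport.create_from_objects(
    client=client,
    project_id="PROJECT_ID",  # placeholder
    name="example-label-import",
    labels=labels,
)
label_import.wait_until_done()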
@@ -763,21 +844,26 @@ def create_from_url(cls, client: "labelbox.Client", project_id: str, query_str = cls._get_url_mutation() return cls( client, - client.execute(query_str, - params={ - "fileUrl": url, - "projectId": project_id, - 'name': name - })["createLabelImport"]) + client.execute( + query_str, + params={ + "fileUrl": url, + "projectId": project_id, + "name": name, + }, + )["createLabelImport"], + ) else: raise ValueError(f"Url {url} is not reachable") @classmethod - def from_name(cls, - client: "labelbox.Client", - project_id: str, - name: str, - as_json: bool = False) -> "LabelImport": + def from_name( + cls, + client: "labelbox.Client", + project_id: str, + name: str, + as_json: bool = False, + ) -> "LabelImport": """ Retrieves an label import job. @@ -824,18 +910,23 @@ def _get_file_mutation(cls) -> str: }""" % query.results_query_part(cls) @classmethod - def _create_label_import_from_bytes(cls, client: "labelbox.Client", - project_id: str, name: str, - bytes_data: BinaryIO, - content_len: int) -> "LabelImport": + def _create_label_import_from_bytes( + cls, + client: "labelbox.Client", + project_id: str, + name: str, + bytes_data: BinaryIO, + content_len: int, + ) -> "LabelImport": file_name = f"{project_id}__{name}.ndjson" variables = { "file": None, "contentLength": content_len, "projectId": project_id, - "name": name + "name": name, } query_str = cls._get_file_mutation() - res = cls._create_from_bytes(client, variables, query_str, file_name, - bytes_data) + res = cls._create_from_bytes( + client, variables, query_str, file_name, bytes_data + ) return cls(client, res["createLabelImport"]) diff --git a/libs/labelbox/src/labelbox/schema/asset_attachment.py b/libs/labelbox/src/labelbox/schema/asset_attachment.py index fba542011..0d5598c84 100644 --- a/libs/labelbox/src/labelbox/schema/asset_attachment.py +++ b/libs/labelbox/src/labelbox/schema/asset_attachment.py @@ -7,12 +7,12 @@ class AttachmentType(str, Enum): - @classmethod def __missing__(cls, value: object): if str(value) == "TEXT": warnings.warn( - "The TEXT attachment type is deprecated. Use RAW_TEXT instead.") + "The TEXT attachment type is deprecated. Use RAW_TEXT instead." + ) return cls.RAW_TEXT return value @@ -44,13 +44,13 @@ class AssetAttachment(DbObject): @classmethod def validate_attachment_json(cls, attachment_json: Dict[str, str]) -> None: - for required_key in ['type', 'value']: + for required_key in ["type", "value"]: if required_key not in attachment_json: raise ValueError( f"Must provide a `{required_key}` key for each attachment. Found {attachment_json}." 
) - cls.validate_attachment_value(attachment_json['value']) - cls.validate_attachment_type(attachment_json['type']) + cls.validate_attachment_value(attachment_json["value"]) + cls.validate_attachment_type(attachment_json["type"]) @classmethod def validate_attachment_value(cls, attachment_value: str) -> None: @@ -75,10 +75,12 @@ def delete(self) -> None: }""" self.client.execute(query_str, {"attachment_id": self.uid}) - def update(self, - name: Optional[str] = None, - type: Optional[str] = None, - value: Optional[str] = None): + def update( + self, + name: Optional[str] = None, + type: Optional[str] = None, + value: Optional[str] = None, + ): """Updates an attachment on the data row.""" if not name and not type and value is None: raise ValueError( @@ -101,9 +103,10 @@ def update(self, data: {name: $name, type: $type, value: $value} ) { id name type value } }""" - res = (self.client.execute(query_str, - query_params))['updateDataRowAttachment'] + res = (self.client.execute(query_str, query_params))[ + "updateDataRowAttachment" + ] - self.attachment_name = res['name'] - self.attachment_value = res['value'] - self.attachment_type = res['type'] + self.attachment_name = res["name"] + self.attachment_value = res["value"] + self.attachment_type = res["type"] diff --git a/libs/labelbox/src/labelbox/schema/batch.py b/libs/labelbox/src/labelbox/schema/batch.py index 313d02c16..7566a73f6 100644 --- a/libs/labelbox/src/labelbox/schema/batch.py +++ b/libs/labelbox/src/labelbox/schema/batch.py @@ -18,7 +18,7 @@ class Batch(DbObject): - """ A Batch is a group of data rows submitted to a project for labeling + """A Batch is a group of data rows submitted to a project for labeling Attributes: name (str) @@ -30,6 +30,7 @@ class Batch(DbObject): created_by (Relationship): `ToOne` relationship to User """ + name = Field.String("name") created_at = Field.DateTime("created_at") updated_at = Field.DateTime("updated_at") @@ -39,18 +40,15 @@ class Batch(DbObject): # Relationships created_by = Relationship.ToOne("User") - def __init__(self, - client, - project_id, - *args, - failed_data_row_ids=[], - **kwargs): + def __init__( + self, client, project_id, *args, failed_data_row_ids=[], **kwargs + ): super().__init__(client, *args, **kwargs) self.project_id = project_id self._failed_data_row_ids = failed_data_row_ids - def project(self) -> 'Project': # type: ignore - """ Returns Project which this Batch belongs to + def project(self) -> "Project": # type: ignore + """Returns Project which this Batch belongs to Raises: LabelboxError: if the project is not found @@ -69,7 +67,7 @@ def project(self) -> 'Project': # type: ignore return Entity.Project(self.client, response["project"]) def remove_queued_data_rows(self) -> None: - """ Removes remaining queued data rows from the batch and labeling queue. + """Removes remaining queued data rows from the batch and labeling queue. Args: batch (Batch): Batch to remove queued data rows from @@ -80,17 +78,21 @@ def remove_queued_data_rows(self) -> None: self.client.execute( """mutation RemoveQueuedDataRowsFromBatchPyApi($%s: ID!, $%s: ID!) 
{ project(where: {id: $%s}) { removeQueuedDataRowsFromBatch(batchId: $%s) { id } } - }""" % (project_id_param, batch_id_param, project_id_param, - batch_id_param), { - project_id_param: self.project_id, - batch_id_param: self.uid - }, - experimental=True) - - def export_data_rows(self, - timeout_seconds=120, - include_metadata: bool = False) -> Generator: - """ Returns a generator that produces all data rows that are currently + }""" + % ( + project_id_param, + batch_id_param, + project_id_param, + batch_id_param, + ), + {project_id_param: self.project_id, batch_id_param: self.uid}, + experimental=True, + ) + + def export_data_rows( + self, timeout_seconds=120, include_metadata: bool = False + ) -> Generator: + """Returns a generator that produces all data rows that are currently in this batch. Note: For efficiency, the data are cached for 30 minutes. Newly created data rows will not appear @@ -106,7 +108,8 @@ def export_data_rows(self, """ warnings.warn( "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) + DeprecationWarning, + ) id_param = "batchId" metadata_param = "includeMetadataInput" @@ -115,10 +118,10 @@ def export_data_rows(self, """ % (id_param, metadata_param, id_param, metadata_param) sleep_time = 2 while True: - res = self.client.execute(query_str, { - id_param: self.uid, - metadata_param: include_metadata - }) + res = self.client.execute( + query_str, + {id_param: self.uid, metadata_param: include_metadata}, + ) res = res["exportBatchDataRows"] if res["status"] == "COMPLETE": download_url = res["downloadUrl"] @@ -126,7 +129,8 @@ def export_data_rows(self, response.raise_for_status() reader = parser.reader(StringIO(response.text)) return ( - Entity.DataRow(self.client, result) for result in reader) + Entity.DataRow(self.client, result) for result in reader + ) elif res["status"] == "FAILED": raise LabelboxError("Data row export failed.") @@ -136,14 +140,15 @@ def export_data_rows(self, f"Unable to export data rows within {timeout_seconds} seconds." ) - logger.debug("Batch '%s' data row export, waiting for server...", - self.uid) + logger.debug( + "Batch '%s' data row export, waiting for server...", self.uid + ) time.sleep(sleep_time) def delete(self) -> None: - """ Deletes the given batch. + """Deletes the given batch. - Note: Batch deletion for batches that has labels is forbidden. + Note: Batch deletion for batches that has labels is forbidden. Args: batch (Batch): Batch to remove queued data rows from @@ -151,17 +156,22 @@ def delete(self) -> None: project_id_param = "projectId" batch_id_param = "batchId" - self.client.execute("""mutation DeleteBatchPyApi($%s: ID!, $%s: ID!) { + self.client.execute( + """mutation DeleteBatchPyApi($%s: ID!, $%s: ID!) { project(where: {id: $%s}) { deleteBatch(batchId: $%s) { deletedBatchId } } - }""" % (project_id_param, batch_id_param, project_id_param, - batch_id_param), { - project_id_param: self.project_id, - batch_id_param: self.uid - }, - experimental=True) + }""" + % ( + project_id_param, + batch_id_param, + project_id_param, + batch_id_param, + ), + {project_id_param: self.project_id, batch_id_param: self.uid}, + experimental=True, + ) def delete_labels(self, set_labels_as_template=False) -> None: - """ Deletes labels that were created for data rows in the batch. 
+ """Deletes labels that were created for data rows in the batch. Args: batch (Batch): Batch to remove queued data rows from @@ -174,17 +184,24 @@ def delete_labels(self, set_labels_as_template=False) -> None: res = self.client.execute( """mutation DeleteBatchLabelsPyApi($%s: ID!, $%s: ID!, $%s: DeleteBatchLabelsType!) { project(where: {id: $%s}) { deleteBatchLabels(batchId: $%s, data:{ type: $%s }) { deletedLabelIds } } - }""" % (project_id_param, batch_id_param, type_param, project_id_param, - batch_id_param, type_param), { - project_id_param: - self.project_id, - batch_id_param: - self.uid, - type_param: - "RequeueDataWithLabelAsTemplate" - if set_labels_as_template else "RequeueData" - }, - experimental=True) + }""" + % ( + project_id_param, + batch_id_param, + type_param, + project_id_param, + batch_id_param, + type_param, + ), + { + project_id_param: self.project_id, + batch_id_param: self.uid, + type_param: "RequeueDataWithLabelAsTemplate" + if set_labels_as_template + else "RequeueData", + }, + experimental=True, + ) return res # modify this function to return an empty list if there are no failed data rows diff --git a/libs/labelbox/src/labelbox/schema/benchmark.py b/libs/labelbox/src/labelbox/schema/benchmark.py index 69cfc2f7f..586530e3c 100644 --- a/libs/labelbox/src/labelbox/schema/benchmark.py +++ b/libs/labelbox/src/labelbox/schema/benchmark.py @@ -3,7 +3,7 @@ class Benchmark(DbObject): - """ Represents a benchmark label. + """Represents a benchmark label. The Benchmarks tool works by interspersing data to be labeled, for which there is a benchmark label, to each person labeling. These @@ -19,6 +19,7 @@ class Benchmark(DbObject): created_by (Relationship): `ToOne` relationship to User reference_label (Relationship): `ToOne` relationship to Label """ + created_at = Field.DateTime("created_at") created_by = Relationship.ToOne("User", False, "created_by") last_activity = Field.DateTime("last_activity") @@ -30,7 +31,10 @@ class Benchmark(DbObject): def delete(self) -> None: label_param = "labelId" query_str = """mutation DeleteBenchmarkPyApi($%s: ID!) 
{ - deleteBenchmark(where: {labelId: $%s}) {id}} """ % (label_param, - label_param) - self.client.execute(query_str, - {label_param: self.reference_label().uid}) + deleteBenchmark(where: {labelId: $%s}) {id}} """ % ( + label_param, + label_param, + ) + self.client.execute( + query_str, {label_param: self.reference_label().uid} + ) diff --git a/libs/labelbox/src/labelbox/schema/bulk_import_request.py b/libs/labelbox/src/labelbox/schema/bulk_import_request.py index 6e65aab58..7caa2c6eb 100644 --- a/libs/labelbox/src/labelbox/schema/bulk_import_request.py +++ b/libs/labelbox/src/labelbox/schema/bulk_import_request.py @@ -8,10 +8,29 @@ from google.api_core import retry from labelbox import parser import requests -from pydantic import ValidationError, BaseModel, Field, field_validator, model_validator, ConfigDict, StringConstraints +from pydantic import ( + ValidationError, + BaseModel, + Field, + field_validator, + model_validator, + ConfigDict, + StringConstraints, +) from typing_extensions import Literal, Annotated -from typing import (Any, List, Optional, BinaryIO, Dict, Iterable, Tuple, Union, - Type, Set, TYPE_CHECKING) +from typing import ( + Any, + List, + Optional, + BinaryIO, + Dict, + Iterable, + Tuple, + Union, + Type, + Set, + TYPE_CHECKING, +) from labelbox import exceptions as lb_exceptions from labelbox import utils @@ -29,11 +48,13 @@ NDJSON_MIME_TYPE = "application/x-ndjson" logger = logging.getLogger(__name__) -#TODO: Deprecate this library in place of labelimport and malprediction import library. +# TODO: Deprecate this library in place of labelimport and malprediction import library. + def _determinants(parent_cls: Any) -> List[str]: return [ - k for k, v in parent_cls.model_fields.items() + k + for k, v in parent_cls.model_fields.items() if v.json_schema_extra and "determinant" in v.json_schema_extra ] @@ -43,8 +64,9 @@ def _make_file_name(project_id: str, name: str) -> str: # TODO(gszpak): move it to client.py -def _make_request_data(project_id: str, name: str, content_length: int, - file_name: str) -> dict: +def _make_request_data( + project_id: str, name: str, content_length: int, file_name: str +) -> dict: query_str = """mutation createBulkImportRequestFromFilePyApi( $projectId: ID!, $name: String!, $file: Upload!, $contentLength: Int!) 
{ createBulkImportRequest(data: { @@ -63,26 +85,30 @@ def _make_request_data(project_id: str, name: str, content_length: int, "projectId": project_id, "name": name, "file": None, - "contentLength": content_length + "contentLength": content_length, } operations = json.dumps({"variables": variables, "query": query_str}) return { "operations": operations, - "map": (None, json.dumps({file_name: ["variables.file"]})) + "map": (None, json.dumps({file_name: ["variables.file"]})), } def _send_create_file_command( - client, request_data: dict, file_name: str, - file_data: Tuple[str, Union[bytes, BinaryIO], str]) -> dict: - + client, + request_data: dict, + file_name: str, + file_data: Tuple[str, Union[bytes, BinaryIO], str], +) -> dict: response = client.execute(data=request_data, files={file_name: file_data}) if not response.get("createBulkImportRequest", None): raise lb_exceptions.LabelboxError( - "Failed to create BulkImportRequest, message: %s" % - response.get("errors", None) or response.get("error", None)) + "Failed to create BulkImportRequest, message: %s" + % response.get("errors", None) + or response.get("error", None) + ) return response @@ -101,6 +127,7 @@ class BulkImportRequest(DbObject): project (Relationship): `ToOne` relationship to Project created_by (Relationship): `ToOne` relationship to User """ + name = lb_Field.String("name") state = lb_Field.Enum(BulkImportRequestState, "state") input_file_url = lb_Field.String("input_file_url") @@ -182,8 +209,7 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]: return parser.loads(response.text) def refresh(self) -> None: - """Synchronizes values of all fields with the database. - """ + """Synchronizes values of all fields with the database.""" query_str, params = query.get_single(BulkImportRequest, self.uid) res = self.client.execute(query_str, params) res = res[utils.camel_case(BulkImportRequest.type_name())] @@ -207,16 +233,21 @@ def wait_until_done(self, sleep_time_seconds: int = 5) -> None: time.sleep(sleep_time_seconds) self.__exponential_backoff_refresh() - @retry.Retry(predicate=retry.if_exception_type(lb_exceptions.ApiLimitError, - lb_exceptions.TimeoutError, - lb_exceptions.NetworkError)) + @retry.Retry( + predicate=retry.if_exception_type( + lb_exceptions.ApiLimitError, + lb_exceptions.TimeoutError, + lb_exceptions.NetworkError, + ) + ) def __exponential_backoff_refresh(self) -> None: self.refresh() @classmethod - def from_name(cls, client, project_id: str, - name: str) -> 'BulkImportRequest': - """ Fetches existing BulkImportRequest. + def from_name( + cls, client, project_id: str, name: str + ) -> "BulkImportRequest": + """Fetches existing BulkImportRequest. Args: client (Client): a Labelbox client @@ -238,15 +269,12 @@ def from_name(cls, client, project_id: str, """ % query.results_query_part(cls) params = {"projectId": project_id, "name": name} response = client.execute(query_str, params=params) - return cls(client, response['bulkImportRequest']) + return cls(client, response["bulkImportRequest"]) @classmethod - def create_from_url(cls, - client, - project_id: str, - name: str, - url: str, - validate=True) -> 'BulkImportRequest': + def create_from_url( + cls, client, project_id: str, name: str, url: str, validate=True + ) -> "BulkImportRequest": """ Creates a BulkImportRequest from a publicly accessible URL to an ndjson file with predictions. 
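# The module TODO above flags BulkImportRequest for deprecation in favor of
# LabelImport / MALPredictionImport; this hedged sketch of the legacy URL flow
# is illustrative only, with placeholder ids and URL.
from labelbox.schema.bulk_import_request import BulkImportRequest
import labelbox as lb

client = lb.Client(api_key="YOUR_API_KEY")  # placeholder credential

bulk_import = BulkImportRequest.create_from_url(
    client=client,
    project_id="PROJECT_ID",                       # placeholder
    name="example-bulk-import",
    url="https://example.com/predictions.ndjson",  # placeholder NDJSON URL
)
bulk_import.wait_until_done()
print(bulk_import.state)  # BulkImportRequestState once the job settles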
@@ -282,17 +310,19 @@ def create_from_url(cls, """ % query.results_query_part(cls) params = {"projectId": project_id, "name": name, "fileUrl": url} bulk_import_request_response = client.execute(query_str, params=params) - return cls(client, - bulk_import_request_response["createBulkImportRequest"]) + return cls( + client, bulk_import_request_response["createBulkImportRequest"] + ) @classmethod - def create_from_objects(cls, - client, - project_id: str, - name: str, - predictions: Union[Iterable[Dict], - Iterable["Label"]], - validate=True) -> 'BulkImportRequest': + def create_from_objects( + cls, + client, + project_id: str, + name: str, + predictions: Union[Iterable[Dict], Iterable["Label"]], + validate=True, + ) -> "BulkImportRequest": """ Creates a `BulkImportRequest` from an iterable of dictionaries. @@ -332,27 +362,27 @@ def create_from_objects(cls, data_str = parser.dumps(ndjson_predictions) if not data_str: - raise ValueError('annotations cannot be empty') + raise ValueError("annotations cannot be empty") - data = data_str.encode('utf-8') + data = data_str.encode("utf-8") file_name = _make_file_name(project_id, name) - request_data = _make_request_data(project_id, name, len(data_str), - file_name) + request_data = _make_request_data( + project_id, name, len(data_str), file_name + ) file_data = (file_name, data, NDJSON_MIME_TYPE) - response_data = _send_create_file_command(client, - request_data=request_data, - file_name=file_name, - file_data=file_data) + response_data = _send_create_file_command( + client, + request_data=request_data, + file_name=file_name, + file_data=file_data, + ) return cls(client, response_data["createBulkImportRequest"]) @classmethod - def create_from_local_file(cls, - client, - project_id: str, - name: str, - file: Path, - validate_file=True) -> 'BulkImportRequest': + def create_from_local_file( + cls, client, project_id: str, name: str, file: Path, validate_file=True + ) -> "BulkImportRequest": """ Creates a BulkImportRequest from a local ndjson file with predictions. @@ -369,10 +399,11 @@ def create_from_local_file(cls, """ file_name = _make_file_name(project_id, name) content_length = file.stat().st_size - request_data = _make_request_data(project_id, name, content_length, - file_name) + request_data = _make_request_data( + project_id, name, content_length, file_name + ) - with file.open('rb') as f: + with file.open("rb") as f: if validate_file: reader = parser.reader(f) # ensure that the underlying json load call is valid @@ -386,12 +417,13 @@ def create_from_local_file(cls, else: f.seek(0) file_data = (file.name, f, NDJSON_MIME_TYPE) - response_data = _send_create_file_command(client, request_data, - file_name, file_data) + response_data = _send_create_file_command( + client, request_data, file_name, file_data + ) return cls(client, response_data["createBulkImportRequest"]) def delete(self) -> None: - """ Deletes the import job and also any annotations created by this import. + """Deletes the import job and also any annotations created by this import. Returns: None @@ -406,8 +438,9 @@ def delete(self) -> None: self.client.execute(query_str, {id_param: self.uid}) -def _validate_ndjson(lines: Iterable[Dict[str, Any]], - project: "Project") -> None: +def _validate_ndjson( + lines: Iterable[Dict[str, Any]], project: "Project" +) -> None: """ Client side validation of an ndjson object. 
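# An illustrative (not authoritative) NDJSON prediction line, shaped to pass
# the client-side checks in this module: a unique "uuid", a "dataRow" holding
# exactly one of "id" or "globalKey", and a "name" tied to an ontology tool.
# The bounding-box payload itself is an assumed example, not taken from this patch.
import uuid

example_prediction = {
    "uuid": str(uuid.uuid4()),                  # must be unique per import job
    "dataRow": {"globalKey": "my-global-key"},  # or {"id": "<data row id>"}
    "name": "bounding_box_tool",                # placeholder ontology tool name
    "bbox": {"top": 10, "left": 20, "height": 100, "width": 150},
}
predictions = [example_prediction]  # e.g. fed to create_from_objects(...)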
@@ -426,26 +459,29 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], UuidError: Duplicate UUID in upload """ feature_schemas_by_id, feature_schemas_by_name = get_mal_schemas( - project.ontology()) + project.ontology() + ) uids: Set[str] = set() for idx, line in enumerate(lines): try: annotation = NDAnnotation(**line) - annotation.validate_instance(feature_schemas_by_id, - feature_schemas_by_name) + annotation.validate_instance( + feature_schemas_by_id, feature_schemas_by_name + ) uuid = str(annotation.uuid) if uuid in uids: raise lb_exceptions.UuidError( - f'{uuid} already used in this import job, ' - 'must be unique for the project.') + f"{uuid} already used in this import job, " + "must be unique for the project." + ) uids.add(uuid) - except (ValidationError, ValueError, TypeError, - KeyError) as e: + except (ValidationError, ValueError, TypeError, KeyError) as e: raise lb_exceptions.MALValidationError( - f"Invalid NDJson on line {idx}") from e + f"Invalid NDJson on line {idx}" + ) from e -#The rest of this file contains objects for MAL validation +# The rest of this file contains objects for MAL validation def parse_classification(tool): """ Parses a classification from an ontology. Only radio, checklist, and text are supported for mal @@ -456,20 +492,20 @@ def parse_classification(tool): Returns: dict """ - if tool['type'] in ['radio', 'checklist']: - option_schema_ids = [r['featureSchemaId'] for r in tool['options']] - option_names = [r['value'] for r in tool['options']] + if tool["type"] in ["radio", "checklist"]: + option_schema_ids = [r["featureSchemaId"] for r in tool["options"]] + option_names = [r["value"] for r in tool["options"]] return { - 'tool': tool['type'], - 'featureSchemaId': tool['featureSchemaId'], - 'name': tool['name'], - 'options': [*option_schema_ids, *option_names] + "tool": tool["type"], + "featureSchemaId": tool["featureSchemaId"], + "name": tool["name"], + "options": [*option_schema_ids, *option_names], } - elif tool['type'] == 'text': + elif tool["type"] == "text": return { - 'tool': tool['type'], - 'name': tool['name'], - 'featureSchemaId': tool['featureSchemaId'] + "tool": tool["type"], + "name": tool["name"], + "featureSchemaId": tool["featureSchemaId"], } @@ -485,31 +521,32 @@ def get_mal_schemas(ontology): valid_feature_schemas_by_schema_id = {} valid_feature_schemas_by_name = {} - for tool in ontology.normalized['tools']: + for tool in ontology.normalized["tools"]: classifications = [ parse_classification(classification_tool) - for classification_tool in tool['classifications'] + for classification_tool in tool["classifications"] ] classifications_by_schema_id = { - v['featureSchemaId']: v for v in classifications + v["featureSchemaId"]: v for v in classifications } - classifications_by_name = {v['name']: v for v in classifications} - valid_feature_schemas_by_schema_id[tool['featureSchemaId']] = { - 'tool': tool['tool'], - 'classificationsBySchemaId': classifications_by_schema_id, - 'classificationsByName': classifications_by_name, - 'name': tool['name'] + classifications_by_name = {v["name"]: v for v in classifications} + valid_feature_schemas_by_schema_id[tool["featureSchemaId"]] = { + "tool": tool["tool"], + "classificationsBySchemaId": classifications_by_schema_id, + "classificationsByName": classifications_by_name, + "name": tool["name"], } - valid_feature_schemas_by_name[tool['name']] = { - 'tool': tool['tool'], - 'classificationsBySchemaId': classifications_by_schema_id, - 'classificationsByName': classifications_by_name, - 'name': 
tool['name'] + valid_feature_schemas_by_name[tool["name"]] = { + "tool": tool["tool"], + "classificationsBySchemaId": classifications_by_schema_id, + "classificationsByName": classifications_by_name, + "name": tool["name"], } - for tool in ontology.normalized['classifications']: - valid_feature_schemas_by_schema_id[ - tool['featureSchemaId']] = parse_classification(tool) - valid_feature_schemas_by_name[tool['name']] = parse_classification(tool) + for tool in ontology.normalized["classifications"]: + valid_feature_schemas_by_schema_id[tool["featureSchemaId"]] = ( + parse_classification(tool) + ) + valid_feature_schemas_by_name[tool["name"]] = parse_classification(tool) return valid_feature_schemas_by_schema_id, valid_feature_schemas_by_name @@ -531,13 +568,12 @@ class FrameLocation(BaseModel): class VideoSupported(BaseModel): - #Note that frames are only allowed as top level inferences for video + # Note that frames are only allowed as top level inferences for video frames: Optional[List[FrameLocation]] = None # Base class for a special kind of union. class SpecialUnion: - def __new__(cls, **kwargs): return cls.build(kwargs) @@ -553,7 +589,8 @@ def get_union_types(cls): union_types = [x for x in cls.__orig_bases__ if hasattr(x, "__args__")] if len(union_types) < 1: raise TypeError( - "Class {cls} should inherit from a union of objects to build") + "Class {cls} should inherit from a union of objects to build" + ) if len(union_types) > 1: raise TypeError( f"Class {cls} should inherit from exactly one union of objects to build. Found {union_types}" @@ -561,15 +598,14 @@ def get_union_types(cls): return union_types[0].__args__[0].__args__ @classmethod - def build(cls: Any, data: Union[dict, - BaseModel]) -> "NDBase": + def build(cls: Any, data: Union[dict, BaseModel]) -> "NDBase": """ - Checks through all objects in the union to see which matches the input data. - Args: - data (Union[dict, BaseModel]) : The data for constructing one of the objects in the union - raises: - KeyError: data does not contain the determinant fields for any of the types supported by this SpecialUnion - ValidationError: Error while trying to construct a specific object in the union + Checks through all objects in the union to see which matches the input data. 
+ Args: + data (Union[dict, BaseModel]) : The data for constructing one of the objects in the union + raises: + KeyError: data does not contain the determinant fields for any of the types supported by this SpecialUnion + ValidationError: Error while trying to construct a specific object in the union """ if isinstance(data, BaseModel): @@ -588,11 +624,11 @@ def build(cls: Any, data: Union[dict, matched = type_ if matched is not None: - #These two have the exact same top level keys + # These two have the exact same top level keys if matched in [NDRadio, NDText]: - if isinstance(data['answer'], dict): + if isinstance(data["answer"], dict): matched = NDRadio - elif isinstance(data['answer'], str): + elif isinstance(data["answer"], str): matched = NDText else: raise TypeError( @@ -606,10 +642,10 @@ def build(cls: Any, data: Union[dict, @classmethod def schema(cls): - results = {'definitions': {}} + results = {"definitions": {}} for cl in cls.get_union_types(): schema = cl.schema() - results['definitions'].update(schema.pop('definitions')) + results["definitions"].update(schema.pop("definitions")) results[cl.__name__] = schema return results @@ -626,7 +662,8 @@ class NDFeatureSchema(BaseModel): def most_set_one(self): if self.schemaId is None and self.name is None: raise ValueError( - "Must set either schemaId or name for all feature schemas") + "Must set either schemaId or name for all feature schemas" + ) return self @@ -636,16 +673,19 @@ class NDBase(NDFeatureSchema): dataRow: DataRow model_config = ConfigDict(extra="forbid") - def validate_feature_schemas(self, valid_feature_schemas_by_id, - valid_feature_schemas_by_name): + def validate_feature_schemas( + self, valid_feature_schemas_by_id, valid_feature_schemas_by_name + ): if self.name: if self.name not in valid_feature_schemas_by_name: raise ValueError( f"Name {self.name} is not valid for the provided project's ontology." ) - if self.ontology_type != valid_feature_schemas_by_name[ - self.name]['tool']: + if ( + self.ontology_type + != valid_feature_schemas_by_name[self.name]["tool"] + ): raise ValueError( f"Name {self.name} does not map to the assigned tool {valid_feature_schemas_by_name[self.name]['tool']}" ) @@ -656,16 +696,20 @@ def validate_feature_schemas(self, valid_feature_schemas_by_id, f"Schema id {self.schemaId} is not valid for the provided project's ontology." 
) - if self.ontology_type != valid_feature_schemas_by_id[ - self.schemaId]['tool']: + if ( + self.ontology_type + != valid_feature_schemas_by_id[self.schemaId]["tool"] + ): raise ValueError( f"Schema id {self.schemaId} does not map to the assigned tool {valid_feature_schemas_by_id[self.schemaId]['tool']}" ) - def validate_instance(self, valid_feature_schemas_by_id, - valid_feature_schemas_by_name): - self.validate_feature_schemas(valid_feature_schemas_by_id, - valid_feature_schemas_by_name) + def validate_instance( + self, valid_feature_schemas_by_id, valid_feature_schemas_by_name + ): + self.validate_feature_schemas( + valid_feature_schemas_by_id, valid_feature_schemas_by_name + ) ###### Classifications ###### @@ -674,36 +718,42 @@ def validate_instance(self, valid_feature_schemas_by_id, class NDText(NDBase): ontology_type: Literal["text"] = "text" answer: str = Field(json_schema_extra={"determinant": True}) - #No feature schema to check + # No feature schema to check class NDChecklist(VideoSupported, NDBase): ontology_type: Literal["checklist"] = "checklist" - answers: List[NDFeatureSchema] = Field(json_schema_extra={"determinant": True}) + answers: List[NDFeatureSchema] = Field( + json_schema_extra={"determinant": True} + ) - @field_validator('answers', mode="before") + @field_validator("answers", mode="before") def validate_answers(cls, value, field): - #constr not working with mypy. + # constr not working with mypy. if not len(value): raise ValueError("Checklist answers should not be empty") return value - def validate_feature_schemas(self, valid_feature_schemas_by_id, - valid_feature_schemas_by_name): - #Test top level feature schema for this tool - super(NDChecklist, - self).validate_feature_schemas(valid_feature_schemas_by_id, - valid_feature_schemas_by_name) - #Test the feature schemas provided to the answer field - if len(set([answer.name or answer.schemaId for answer in self.answers - ])) != len(self.answers): + def validate_feature_schemas( + self, valid_feature_schemas_by_id, valid_feature_schemas_by_name + ): + # Test top level feature schema for this tool + super(NDChecklist, self).validate_feature_schemas( + valid_feature_schemas_by_id, valid_feature_schemas_by_name + ) + # Test the feature schemas provided to the answer field + if len( + set([answer.name or answer.schemaId for answer in self.answers]) + ) != len(self.answers): raise ValueError( - f"Duplicated featureSchema found for checklist {self.uuid}") + f"Duplicated featureSchema found for checklist {self.uuid}" + ) for answer in self.answers: - options = valid_feature_schemas_by_name[ - self. - name]['options'] if self.name else valid_feature_schemas_by_id[ - self.schemaId]['options'] + options = ( + valid_feature_schemas_by_name[self.name]["options"] + if self.name + else valid_feature_schemas_by_id[self.schemaId]["options"] + ) if answer.name not in options and answer.schemaId not in options: raise ValueError( f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. 
Found {answer}" @@ -714,26 +764,35 @@ class NDRadio(VideoSupported, NDBase): ontology_type: Literal["radio"] = "radio" answer: NDFeatureSchema = Field(json_schema_extra={"determinant": True}) - def validate_feature_schemas(self, valid_feature_schemas_by_id, - valid_feature_schemas_by_name): - super(NDRadio, - self).validate_feature_schemas(valid_feature_schemas_by_id, - valid_feature_schemas_by_name) - options = valid_feature_schemas_by_name[ - self.name]['options'] if self.name else valid_feature_schemas_by_id[ - self.schemaId]['options'] - if self.answer.name not in options and self.answer.schemaId not in options: + def validate_feature_schemas( + self, valid_feature_schemas_by_id, valid_feature_schemas_by_name + ): + super(NDRadio, self).validate_feature_schemas( + valid_feature_schemas_by_id, valid_feature_schemas_by_name + ) + options = ( + valid_feature_schemas_by_name[self.name]["options"] + if self.name + else valid_feature_schemas_by_id[self.schemaId]["options"] + ) + if ( + self.answer.name not in options + and self.answer.schemaId not in options + ): raise ValueError( f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found {self.answer.name or self.answer.schemaId}" ) -#A union with custom construction logic to improve error messages +# A union with custom construction logic to improve error messages class NDClassification( - SpecialUnion, - Type[Union[ # type: ignore - NDText, NDRadio, NDChecklist]]): - ... + SpecialUnion, + Type[ + Union[ # type: ignore + NDText, NDRadio, NDChecklist + ] + ], +): ... ###### Tools ###### @@ -742,35 +801,41 @@ class NDClassification( class NDBaseTool(NDBase): classifications: List[NDClassification] = [] - #This is indepdent of our problem - def validate_feature_schemas(self, valid_feature_schemas_by_id, - valid_feature_schemas_by_name): - super(NDBaseTool, - self).validate_feature_schemas(valid_feature_schemas_by_id, - valid_feature_schemas_by_name) + # This is indepdent of our problem + def validate_feature_schemas( + self, valid_feature_schemas_by_id, valid_feature_schemas_by_name + ): + super(NDBaseTool, self).validate_feature_schemas( + valid_feature_schemas_by_id, valid_feature_schemas_by_name + ) for classification in self.classifications: classification.validate_feature_schemas( - valid_feature_schemas_by_name[ - self.name]['classificationsBySchemaId'] - if self.name else valid_feature_schemas_by_id[self.schemaId] - ['classificationsBySchemaId'], valid_feature_schemas_by_name[ - self.name]['classificationsByName'] - if self.name else valid_feature_schemas_by_id[ - self.schemaId]['classificationsByName']) - - @field_validator('classifications', mode="before") + valid_feature_schemas_by_name[self.name][ + "classificationsBySchemaId" + ] + if self.name + else valid_feature_schemas_by_id[self.schemaId][ + "classificationsBySchemaId" + ], + valid_feature_schemas_by_name[self.name][ + "classificationsByName" + ] + if self.name + else valid_feature_schemas_by_id[self.schemaId][ + "classificationsByName" + ], + ) + + @field_validator("classifications", mode="before") def validate_subclasses(cls, value, field): - #Create uuid and datarow id so we don't have to define classification objects twice - #This is caused by the fact that we require these ids for top level classifications but not for subclasses + # Create uuid and datarow id so we don't have to define classification objects twice + # This is caused by the fact that we require these ids for top level classifications but not for subclasses results = [] - 
dummy_id = 'child'.center(25, '_') + dummy_id = "child".center(25, "_") for row in value: - results.append({ - **row, 'dataRow': { - 'id': dummy_id - }, - 'uuid': str(uuid4()) - }) + results.append( + {**row, "dataRow": {"id": dummy_id}, "uuid": str(uuid4())} + ) return results @@ -778,11 +843,12 @@ class NDPolygon(NDBaseTool): ontology_type: Literal["polygon"] = "polygon" polygon: List[Point] = Field(json_schema_extra={"determinant": True}) - @field_validator('polygon') + @field_validator("polygon") def is_geom_valid(cls, v): if len(v) < 3: raise ValueError( - f"A polygon must have at least 3 points to be valid. Found {v}") + f"A polygon must have at least 3 points to be valid. Found {v}" + ) return v @@ -790,24 +856,25 @@ class NDPolyline(NDBaseTool): ontology_type: Literal["line"] = "line" line: List[Point] = Field(json_schema_extra={"determinant": True}) - @field_validator('line') + @field_validator("line") def is_geom_valid(cls, v): if len(v) < 2: raise ValueError( - f"A line must have at least 2 points to be valid. Found {v}") + f"A line must have at least 2 points to be valid. Found {v}" + ) return v class NDRectangle(NDBaseTool): ontology_type: Literal["rectangle"] = "rectangle" bbox: Bbox = Field(json_schema_extra={"determinant": True}) - #Could check if points are positive + # Could check if points are positive class NDPoint(NDBaseTool): ontology_type: Literal["point"] = "point" point: Point = Field(json_schema_extra={"determinant": True}) - #Could check if points are positive + # Could check if points are positive class EntityLocation(BaseModel): @@ -819,17 +886,18 @@ class NDTextEntity(NDBaseTool): ontology_type: Literal["named-entity"] = "named-entity" location: EntityLocation = Field(json_schema_extra={"determinant": True}) - @field_validator('location') + @field_validator("location") def is_valid_location(cls, v): if isinstance(v, BaseModel): v = v.model_dump() if len(v) < 2: raise ValueError( - f"A line must have at least 2 points to be valid. Found {v}") - if v['start'] < 0: + f"A line must have at least 2 points to be valid. Found {v}" + ) + if v["start"] < 0: raise ValueError(f"Text location must be positive. Found {v}") - if v['start'] > v['end']: + if v["start"] > v["end"]: raise ValueError( f"Text start location must be less or equal than end. Found {v}" ) @@ -840,7 +908,7 @@ class RLEMaskFeatures(BaseModel): counts: List[int] size: List[int] - @field_validator('counts') + @field_validator("counts") def validate_counts(cls, counts): if not all([count >= 0 for count in counts]): raise ValueError( @@ -848,7 +916,7 @@ def validate_counts(cls, counts): ) return counts - @field_validator('size') + @field_validator("size") def validate_size(cls, size): if len(size) != 2: raise ValueError( @@ -856,7 +924,8 @@ def validate_size(cls, size): ) if not all([count > 0 for count in size]): raise ValueError( - f"Mask `size` should be a postitive int. Found : {size}") + f"Mask `size` should be a postitive int. Found : {size}" + ) return size @@ -869,9 +938,9 @@ class URIMaskFeatures(BaseModel): instanceURI: str colorRGB: Union[List[int], Tuple[int, int, int]] - @field_validator('colorRGB') + @field_validator("colorRGB") def validate_color(cls, colorRGB): - #Does the dtype matter? Can it be a float? + # Does the dtype matter? Can it be a float? if not isinstance(colorRGB, (tuple, list)): raise ValueError( f"Received color that is not a list or tuple. 
Found : {colorRGB}" @@ -882,39 +951,46 @@ def validate_color(cls, colorRGB): ) elif not all([0 <= color <= 255 for color in colorRGB]): raise ValueError( - f"All rgb colors must be between 0 and 255. Found : {colorRGB}") + f"All rgb colors must be between 0 and 255. Found : {colorRGB}" + ) return colorRGB class NDMask(NDBaseTool): ontology_type: Literal["superpixel"] = "superpixel" - mask: Union[URIMaskFeatures, PNGMaskFeatures, - RLEMaskFeatures] = Field(json_schema_extra={"determinant": True}) + mask: Union[URIMaskFeatures, PNGMaskFeatures, RLEMaskFeatures] = Field( + json_schema_extra={"determinant": True} + ) -#A union with custom construction logic to improve error messages +# A union with custom construction logic to improve error messages class NDTool( - SpecialUnion, - Type[Union[ # type: ignore + SpecialUnion, + Type[ + Union[ # type: ignore NDMask, NDTextEntity, NDPoint, NDRectangle, NDPolyline, NDPolygon, - ]]): - ... + ] + ], +): ... class NDAnnotation( - SpecialUnion, - Type[Union[ # type: ignore - NDTool, NDClassification]]): - + SpecialUnion, + Type[ + Union[ # type: ignore + NDTool, NDClassification + ] + ], +): @classmethod def build(cls: Any, data) -> "NDBase": if not isinstance(data, dict): - raise ValueError('value must be dict') + raise ValueError("value must be dict") errors = [] for cl in cls.get_union_types(): try: @@ -922,14 +998,15 @@ def build(cls: Any, data) -> "NDBase": except KeyError as e: errors.append(f"{cl.__name__}: {e}") - raise ValueError('Unable to construct any annotation.\n{}'.format( - "\n".join(errors))) + raise ValueError( + "Unable to construct any annotation.\n{}".format("\n".join(errors)) + ) @classmethod def schema(cls): - data = {'definitions': {}} + data = {"definitions": {}} for type_ in cls.get_union_types(): schema_ = type_.schema() - data['definitions'].update(schema_.pop('definitions')) + data["definitions"].update(schema_.pop("definitions")) data[type_.__name__] = schema_ return data diff --git a/libs/labelbox/src/labelbox/schema/catalog.py b/libs/labelbox/src/labelbox/schema/catalog.py index c377703b1..567bbd777 100644 --- a/libs/labelbox/src/labelbox/schema/catalog.py +++ b/libs/labelbox/src/labelbox/schema/catalog.py @@ -2,12 +2,15 @@ from labelbox.orm.db_object import experimental from labelbox.schema.export_filters import CatalogExportFilters, build_filters -from labelbox.schema.export_params import (CatalogExportParams, - validate_catalog_export_params) +from labelbox.schema.export_params import ( + CatalogExportParams, + validate_catalog_export_params, +) from labelbox.schema.export_task import ExportTask from labelbox.schema.task import Task from typing import TYPE_CHECKING + if TYPE_CHECKING: from labelbox import Client @@ -15,7 +18,7 @@ class Catalog: client: "Client" - def __init__(self, client: 'Client'): + def __init__(self, client: "Client"): self.client = client def export_v2( @@ -43,7 +46,7 @@ def export_v2( >>> task.result """ task, is_streamable = self._export(task_name, filters, params) - if (is_streamable): + if is_streamable: return ExportTask(task, True) return task @@ -72,44 +75,49 @@ def export( task, _ = self._export(task_name, filters, params, streamable=True) return ExportTask(task) - def _export(self, - task_name: Optional[str] = None, - filters: Union[CatalogExportFilters, Dict[str, List[str]], - None] = None, - params: Optional[CatalogExportParams] = None, - streamable: bool = False) -> Tuple[Task, bool]: - - _params = params or CatalogExportParams({ - "attachments": False, - "embeddings": False, - 
"metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "model_run_ids": None, - "project_ids": None, - "interpolated_frames": False, - "all_projects": False, - "all_model_runs": False, - }) + def _export( + self, + task_name: Optional[str] = None, + filters: Union[CatalogExportFilters, Dict[str, List[str]], None] = None, + params: Optional[CatalogExportParams] = None, + streamable: bool = False, + ) -> Tuple[Task, bool]: + _params = params or CatalogExportParams( + { + "attachments": False, + "embeddings": False, + "metadata_fields": False, + "data_row_details": False, + "project_details": False, + "performance_details": False, + "label_details": False, + "media_type_override": None, + "model_run_ids": None, + "project_ids": None, + "interpolated_frames": False, + "all_projects": False, + "all_model_runs": False, + } + ) validate_catalog_export_params(_params) - _filters = filters or CatalogExportFilters({ - "last_activity_at": None, - "label_created_at": None, - "data_row_ids": None, - "global_keys": None, - }) + _filters = filters or CatalogExportFilters( + { + "last_activity_at": None, + "label_created_at": None, + "data_row_ids": None, + "global_keys": None, + } + ) mutation_name = "exportDataRowsInCatalog" create_task_query_str = ( f"mutation {mutation_name}PyApi" f"($input: ExportDataRowsInCatalogInput!)" - f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}") + f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}" + ) - media_type_override = _params.get('media_type_override', None) + media_type_override = _params.get("media_type_override", None) query_params: Dict[str, Any] = { "input": { "taskName": task_name, @@ -121,35 +129,30 @@ def _export(self, }, "isStreamableReady": True, "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeEmbeddings": - _params.get('embeddings', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), - "includePredictions": - _params.get('predictions', False), - "projectIds": - _params.get('project_ids', None), - "modelRunIds": - _params.get('model_run_ids', None), - "allProjects": - _params.get('all_projects', False), - "allModelRuns": - _params.get('all_model_runs', False), + "mediaTypeOverride": media_type_override.value + if media_type_override is not None + else None, + "includeAttachments": _params.get("attachments", False), + "includeEmbeddings": _params.get("embeddings", False), + "includeMetadata": _params.get("metadata_fields", False), + "includeDataRowDetails": _params.get( + "data_row_details", False + ), + "includeProjectDetails": _params.get( + "project_details", False + ), + "includePerformanceDetails": _params.get( + "performance_details", False + ), + "includeLabelDetails": _params.get("label_details", False), + "includeInterpolatedFrames": _params.get( + "interpolated_frames", False + ), + "includePredictions": _params.get("predictions", False), + "projectIds": _params.get("project_ids", None), + "modelRunIds": 
_params.get("model_run_ids", None), + "allProjects": _params.get("all_projects", False), + "allModelRuns": _params.get("all_model_runs", False), }, "streamable": streamable, } @@ -158,9 +161,9 @@ def _export(self, search_query = build_filters(self.client, _filters) query_params["input"]["filters"]["searchQuery"]["query"] = search_query - res = self.client.execute(create_task_query_str, - query_params, - error_log_key="errors") + res = self.client.execute( + create_task_query_str, query_params, error_log_key="errors" + ) res = res[mutation_name] task_id = res["taskId"] is_streamable = res["isStreamable"] diff --git a/libs/labelbox/src/labelbox/schema/confidence_presence_checker.py b/libs/labelbox/src/labelbox/schema/confidence_presence_checker.py index cfdbe0ed3..77d3bfb3f 100644 --- a/libs/labelbox/src/labelbox/schema/confidence_presence_checker.py +++ b/libs/labelbox/src/labelbox/schema/confidence_presence_checker.py @@ -13,8 +13,9 @@ def check(cls, raw_labels: List[Dict[str, Any]]): return len(keys.intersection(set(["confidence"]))) == 1 @classmethod - def _collect_keys_from_list(cls, objects: List[Dict[str, Any]], - keys: Set[str]): + def _collect_keys_from_list( + cls, objects: List[Dict[str, Any]], keys: Set[str] + ): for obj in objects: if isinstance(obj, (list, tuple)): cls._collect_keys_from_list(obj, keys) diff --git a/libs/labelbox/src/labelbox/schema/create_batches_task.py b/libs/labelbox/src/labelbox/schema/create_batches_task.py index eb7b5d150..25ff80917 100644 --- a/libs/labelbox/src/labelbox/schema/create_batches_task.py +++ b/libs/labelbox/src/labelbox/schema/create_batches_task.py @@ -13,9 +13,9 @@ def lru_cache() -> Callable[..., Callable[..., Dict[str, Any]]]: class CreateBatchesTask: - - def __init__(self, client, project_id: str, batch_ids: List[str], - task_ids: List[str]): + def __init__( + self, client, project_id: str, batch_ids: List[str], task_ids: List[str] + ): self.client = client self.project_id = project_id self.batches = batch_ids diff --git a/libs/labelbox/src/labelbox/schema/data_row.py b/libs/labelbox/src/labelbox/schema/data_row.py index b7c9b324d..8987a00f0 100644 --- a/libs/labelbox/src/labelbox/schema/data_row.py +++ b/libs/labelbox/src/labelbox/schema/data_row.py @@ -4,12 +4,24 @@ import json from labelbox.orm import query -from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable, experimental +from labelbox.orm.db_object import ( + DbObject, + Updateable, + BulkDeletable, + experimental, +) from labelbox.orm.model import Entity, Field, Relationship from labelbox.schema.asset_attachment import AttachmentType from labelbox.schema.data_row_metadata import DataRowMetadataField # type: ignore -from labelbox.schema.export_filters import DatarowExportFilters, build_filters, validate_at_least_one_of_data_row_ids_or_global_keys -from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params +from labelbox.schema.export_filters import ( + DatarowExportFilters, + build_filters, + validate_at_least_one_of_data_row_ids_or_global_keys, +) +from labelbox.schema.export_params import ( + CatalogExportParams, + validate_catalog_export_params, +) from labelbox.schema.export_task import ExportTask from labelbox.schema.task import Task @@ -20,16 +32,16 @@ class KeyType(str, Enum): - ID = 'ID' + ID = "ID" """An existing CUID""" - GKEY = 'GKEY' + GKEY = "GKEY" """A Global key, could be existing or non-existing""" - AUTO = 'AUTO' + AUTO = "AUTO" """The key will be auto-generated. 
Only usable for creates""" class DataRow(DbObject, Updateable, BulkDeletable): - """ Internal Labelbox representation of a single piece of data (e.g. image, video, text). + """Internal Labelbox representation of a single piece of data (e.g. image, video, text). Attributes: external_id (str): User-generated file name or identifier @@ -49,6 +61,7 @@ class DataRow(DbObject, Updateable, BulkDeletable): labels (Relationship): `ToMany` relationship to Label attachments (Relationship) `ToMany` relationship with AssetAttachment """ + external_id = Field.String("external_id") global_key = Field.String("global_key") row_data = Field.String("row_data") @@ -59,11 +72,14 @@ class DataRow(DbObject, Updateable, BulkDeletable): dict, graphql_type="DataRowCustomMetadataUpsertInput!", name="metadata_fields", - result_subquery="metadataFields { schemaId name value kind }") - metadata = Field.List(DataRowMetadataField, - name="metadata", - graphql_name="customMetadata", - result_subquery="customMetadata { schemaId value }") + result_subquery="metadataFields { schemaId name value kind }", + ) + metadata = Field.List( + DataRowMetadataField, + name="metadata", + graphql_name="customMetadata", + result_subquery="customMetadata { schemaId value }", + ) # Relationships dataset = Relationship.ToOne("Dataset") @@ -73,7 +89,8 @@ class DataRow(DbObject, Updateable, BulkDeletable): attachments = Relationship.ToMany("AssetAttachment", False, "attachments") supported_meta_types = supported_attachment_types = set( - AttachmentType.__members__) + AttachmentType.__members__ + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -95,12 +112,12 @@ def update(self, **kwargs): row_data = kwargs.get("row_data") if isinstance(row_data, dict): - kwargs['row_data'] = json.dumps(row_data) + kwargs["row_data"] = json.dumps(row_data) super().update(**kwargs) @staticmethod def bulk_delete(data_rows) -> None: - """ Deletes all the given DataRows. + """Deletes all the given DataRows. Args: data_rows (list of DataRow): The DataRows to delete. @@ -108,7 +125,7 @@ def bulk_delete(data_rows) -> None: BulkDeletable._bulk_delete(data_rows, True) def get_winning_label_id(self, project_id: str) -> Optional[str]: - """ Retrieves the winning label ID, i.e. the one that was marked as the + """Retrieves the winning label ID, i.e. the one that was marked as the best for a particular data row, in a project's workflow. Args: @@ -121,21 +138,27 @@ def get_winning_label_id(self, project_id: str) -> Optional[str]: labelingActivity(where: { projectId: $%s }) { selectedLabelId } - }} """ % (data_row_id_param, project_id_param, data_row_id_param, - project_id_param) + }} """ % ( + data_row_id_param, + project_id_param, + data_row_id_param, + project_id_param, + ) - res = self.client.execute(query_str, { - data_row_id_param: self.uid, - project_id_param: project_id, - }) + res = self.client.execute( + query_str, + { + data_row_id_param: self.uid, + project_id_param: project_id, + }, + ) return res["dataRow"]["labelingActivity"]["selectedLabelId"] - def create_attachment(self, - attachment_type, - attachment_value, - attachment_name=None) -> "AssetAttachment": - """ Adds an AssetAttachment to a DataRow. + def create_attachment( + self, attachment_type, attachment_value, attachment_name=None + ) -> "AssetAttachment": + """Adds an AssetAttachment to a DataRow. Labelers can view these attachments while labeling. 
>>> datarow.create_attachment("TEXT", "This is a text message") @@ -151,10 +174,9 @@ def create_attachment(self, ValueError: attachment_type must be one of the supported types. ValueError: attachment_value must be a non-empty string. """ - Entity.AssetAttachment.validate_attachment_json({ - 'type': attachment_type, - 'value': attachment_value - }) + Entity.AssetAttachment.validate_attachment_json( + {"type": attachment_type, "value": attachment_value} + ) attachment_type_param = "type" attachment_value_param = "value" @@ -165,20 +187,29 @@ def create_attachment(self, $%s: AttachmentType!, $%s: String!, $%s: String, $%s: ID!) { createDataRowAttachment(data: { type: $%s value: $%s name: $%s dataRowId: $%s}) {%s}} """ % ( - attachment_type_param, attachment_value_param, - attachment_name_param, data_row_id_param, attachment_type_param, - attachment_value_param, attachment_name_param, data_row_id_param, - query.results_query_part(Entity.AssetAttachment)) + attachment_type_param, + attachment_value_param, + attachment_name_param, + data_row_id_param, + attachment_type_param, + attachment_value_param, + attachment_name_param, + data_row_id_param, + query.results_query_part(Entity.AssetAttachment), + ) res = self.client.execute( - query_str, { + query_str, + { attachment_type_param: attachment_type, attachment_value_param: attachment_value, attachment_name_param: attachment_name, - data_row_id_param: self.uid - }) - return Entity.AssetAttachment(self.client, - res["createDataRowAttachment"]) + data_row_id_param: self.uid, + }, + ) + return Entity.AssetAttachment( + self.client, res["createDataRowAttachment"] + ) @staticmethod def export( @@ -210,12 +241,9 @@ def export( >>> task.wait_till_done() >>> task.result """ - task, _ = DataRow._export(client, - data_rows, - global_keys, - task_name, - params, - streamable=True) + task, _ = DataRow._export( + client, data_rows, global_keys, task_name, params, streamable=True + ) return ExportTask(task) @staticmethod @@ -249,8 +277,9 @@ def export_v2( >>> task.wait_till_done() >>> task.result """ - task, is_streamable = DataRow._export(client, data_rows, global_keys, - task_name, params) + task, is_streamable = DataRow._export( + client, data_rows, global_keys, task_name, params + ) if is_streamable: return ExportTask(task, True) return task @@ -264,21 +293,23 @@ def _export( params: Optional[CatalogExportParams] = None, streamable: bool = False, ) -> Tuple[Task, bool]: - _params = params or CatalogExportParams({ - "attachments": False, - "embeddings": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "model_run_ids": None, - "project_ids": None, - "interpolated_frames": False, - "all_projects": False, - "all_model_runs": False, - }) + _params = params or CatalogExportParams( + { + "attachments": False, + "embeddings": False, + "metadata_fields": False, + "data_row_details": False, + "project_details": False, + "performance_details": False, + "label_details": False, + "media_type_override": None, + "model_run_ids": None, + "project_ids": None, + "interpolated_frames": False, + "all_projects": False, + "all_model_runs": False, + } + ) validate_catalog_export_params(_params) @@ -286,7 +317,8 @@ def _export( create_task_query_str = ( f"mutation {mutation_name}PyApi" f"($input: ExportDataRowsInCatalogInput!)" - f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}") + f"{{{mutation_name}(input: $input){{taskId 
isStreamable}}}}" + ) data_row_ids = [] if data_rows is not None: @@ -296,17 +328,25 @@ def _export( elif isinstance(dr, str): data_row_ids.append(dr) - filters = DatarowExportFilters({ - "data_row_ids": data_row_ids, - "global_keys": None, - }) if data_row_ids else DatarowExportFilters({ - "data_row_ids": None, - "global_keys": global_keys, - }) + filters = ( + DatarowExportFilters( + { + "data_row_ids": data_row_ids, + "global_keys": None, + } + ) + if data_row_ids + else DatarowExportFilters( + { + "data_row_ids": None, + "global_keys": global_keys, + } + ) + ) validate_at_least_one_of_data_row_ids_or_global_keys(filters) search_query = build_filters(client, filters) - media_type_override = _params.get('media_type_override', None) + media_type_override = _params.get("media_type_override", None) if task_name is None: task_name = f"Export v2: data rows {len(data_row_ids)}" @@ -314,48 +354,41 @@ def _export( "input": { "taskName": task_name, "filters": { - "searchQuery": { - "scope": None, - "query": search_query - } + "searchQuery": {"scope": None, "query": search_query} }, "isStreamableReady": True, "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeEmbeddings": - _params.get('embeddings', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), - "projectIds": - _params.get('project_ids', None), - "modelRunIds": - _params.get('model_run_ids', None), - "allProjects": - _params.get('all_projects', False), - "allModelRuns": - _params.get('all_model_runs', False), + "mediaTypeOverride": media_type_override.value + if media_type_override is not None + else None, + "includeAttachments": _params.get("attachments", False), + "includeEmbeddings": _params.get("embeddings", False), + "includeMetadata": _params.get("metadata_fields", False), + "includeDataRowDetails": _params.get( + "data_row_details", False + ), + "includeProjectDetails": _params.get( + "project_details", False + ), + "includePerformanceDetails": _params.get( + "performance_details", False + ), + "includeLabelDetails": _params.get("label_details", False), + "includeInterpolatedFrames": _params.get( + "interpolated_frames", False + ), + "projectIds": _params.get("project_ids", None), + "modelRunIds": _params.get("model_run_ids", None), + "allProjects": _params.get("all_projects", False), + "allModelRuns": _params.get("all_model_runs", False), }, - "streamable": streamable + "streamable": streamable, } } - res = client.execute(create_task_query_str, - query_params, - error_log_key="errors") + res = client.execute( + create_task_query_str, query_params, error_log_key="errors" + ) res = res[mutation_name] task_id = res["taskId"] is_streamable = res["isStreamable"] diff --git a/libs/labelbox/src/labelbox/schema/data_row_metadata.py b/libs/labelbox/src/labelbox/schema/data_row_metadata.py index cb02c32f8..288459a89 100644 --- a/libs/labelbox/src/labelbox/schema/data_row_metadata.py +++ b/libs/labelbox/src/labelbox/schema/data_row_metadata.py @@ -5,15 +5,36 @@ from itertools import chain import warnings -from typing import List, Optional, Dict, 
Union, Callable, Type, Any, Generator, overload +from typing import ( + List, + Optional, + Dict, + Union, + Callable, + Type, + Any, + Generator, + overload, +) from typing_extensions import Annotated from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds from labelbox.schema.identifiable import UniqueId, GlobalKey -from pydantic import BaseModel, Field, StringConstraints, conlist, ConfigDict, model_serializer +from pydantic import ( + BaseModel, + Field, + StringConstraints, + conlist, + ConfigDict, + model_serializer, +) from labelbox.schema.ontology import SchemaId -from labelbox.utils import _CamelCaseMixin, format_iso_datetime, format_iso_from_string +from labelbox.utils import ( + _CamelCaseMixin, + format_iso_datetime, + format_iso_from_string, +) class DataRowMetadataKind(Enum): @@ -28,9 +49,7 @@ class DataRowMetadataKind(Enum): # Metadata schema class DataRowMetadataSchema(BaseModel): uid: SchemaId - name: str = Field(strip_whitespace=True, - min_length=1, - max_length=100) + name: str = Field(strip_whitespace=True, min_length=1, max_length=100) reserved: bool kind: DataRowMetadataKind options: Optional[List["DataRowMetadataSchema"]] = None @@ -39,9 +58,7 @@ class DataRowMetadataSchema(BaseModel): DataRowMetadataSchema.model_rebuild() -Embedding: Type[List[float]] = conlist(float, - min_length=128, - max_length=128) +Embedding: Type[List[float]] = conlist(float, min_length=128, max_length=128) String: Type[str] = Field(max_length=4096) @@ -95,49 +112,53 @@ class _UpsertBatchDataRowMetadata(_CamelCaseMixin): class _DeleteBatchDataRowMetadata(_CamelCaseMixin): data_row_identifier: Union[UniqueId, GlobalKey] schema_ids: List[SchemaId] - + model_config = ConfigDict(arbitrary_types_allowed=True) - + @model_serializer(mode="wrap") def model_serializer(self, handler): res = handler(self) - if 'data_row_identifier' in res.keys(): - key = 'data_row_identifier' - id_type_key = 'id_type' + if "data_row_identifier" in res.keys(): + key = "data_row_identifier" + id_type_key = "id_type" else: - key = 'dataRowIdentifier' - id_type_key = 'idType' + key = "dataRowIdentifier" + id_type_key = "idType" data_row_identifier = res.pop(key) res[key] = { "id": data_row_identifier.key, - id_type_key: data_row_identifier.id_type + id_type_key: data_row_identifier.id_type, } return res -_BatchInputs = Union[List[_UpsertBatchDataRowMetadata], - List[_DeleteBatchDataRowMetadata]] +_BatchInputs = Union[ + List[_UpsertBatchDataRowMetadata], List[_DeleteBatchDataRowMetadata] +] _BatchFunction = Callable[[_BatchInputs], List[DataRowMetadataBatchResponse]] class _UpsertCustomMetadataSchemaEnumOptionInput(_CamelCaseMixin): id: Optional[SchemaId] = None - name: Annotated[str, StringConstraints(strip_whitespace=True, - min_length=1, - max_length=100)] + name: Annotated[ + str, + StringConstraints(strip_whitespace=True, min_length=1, max_length=100), + ] kind: str + class _UpsertCustomMetadataSchemaInput(_CamelCaseMixin): id: Optional[SchemaId] = None - name: Annotated[str, StringConstraints(strip_whitespace=True, - min_length=1, - max_length=100)] + name: Annotated[ + str, + StringConstraints(strip_whitespace=True, min_length=1, max_length=100), + ] kind: str options: Optional[List[_UpsertCustomMetadataSchemaEnumOptionInput]] = None class DataRowMetadataOntology: - """ Ontology for data row metadata + """Ontology for data row metadata Metadata provides additional context for a data rows. Metadata is broken into two classes reserved and custom. 
Reserved fields are defined by Labelbox and used for creating @@ -148,7 +169,6 @@ class DataRowMetadataOntology: """ def __init__(self, client): - self._client = client self._batch_size = 50 # used for uploads and deletes @@ -165,24 +185,24 @@ def _build_ontology(self): f for f in self.fields if f.reserved ] self.reserved_by_id = self._make_id_index(self.reserved_fields) - self.reserved_by_name: Dict[str, Union[DataRowMetadataSchema, Dict[ - str, DataRowMetadataSchema]]] = self._make_name_index( - self.reserved_fields) - self.reserved_by_name_normalized: Dict[ - str, DataRowMetadataSchema] = self._make_normalized_name_index( - self.reserved_fields) + self.reserved_by_name: Dict[ + str, Union[DataRowMetadataSchema, Dict[str, DataRowMetadataSchema]] + ] = self._make_name_index(self.reserved_fields) + self.reserved_by_name_normalized: Dict[str, DataRowMetadataSchema] = ( + self._make_normalized_name_index(self.reserved_fields) + ) # custom fields self.custom_fields: List[DataRowMetadataSchema] = [ f for f in self.fields if not f.reserved ] self.custom_by_id = self._make_id_index(self.custom_fields) - self.custom_by_name: Dict[str, Union[DataRowMetadataSchema, Dict[ - str, - DataRowMetadataSchema]]] = self._make_name_index(self.custom_fields) - self.custom_by_name_normalized: Dict[ - str, DataRowMetadataSchema] = self._make_normalized_name_index( - self.custom_fields) + self.custom_by_name: Dict[ + str, Union[DataRowMetadataSchema, Dict[str, DataRowMetadataSchema]] + ] = self._make_name_index(self.custom_fields) + self.custom_by_name_normalized: Dict[str, DataRowMetadataSchema] = ( + self._make_normalized_name_index(self.custom_fields) + ) @staticmethod def _lookup_in_index_by_name(reserved_index, custom_index, name): @@ -197,7 +217,7 @@ def _lookup_in_index_by_name(reserved_index, custom_index, name): def get_by_name( self, name: str ) -> Union[DataRowMetadataSchema, Dict[str, DataRowMetadataSchema]]: - """ Get metadata by name + """Get metadata by name >>> mdo.get_by_name(name) @@ -210,23 +230,27 @@ def get_by_name( Raises: KeyError: When provided name is not presented in neither reserved nor custom metadata list """ - return self._lookup_in_index_by_name(self.reserved_by_name, - self.custom_by_name, name) + return self._lookup_in_index_by_name( + self.reserved_by_name, self.custom_by_name, name + ) def _get_by_name_normalized(self, name: str) -> DataRowMetadataSchema: - """ Get metadata by name. For options, it provides the option schema instead of list of - options + """Get metadata by name. 
For options, it provides the option schema instead of list of + options """ # using `normalized` indices to find options by name as well - return self._lookup_in_index_by_name(self.reserved_by_name_normalized, - self.custom_by_name_normalized, - name) + return self._lookup_in_index_by_name( + self.reserved_by_name_normalized, + self.custom_by_name_normalized, + name, + ) @staticmethod def _make_name_index( - fields: List[DataRowMetadataSchema] - ) -> Dict[str, Union[DataRowMetadataSchema, Dict[str, - DataRowMetadataSchema]]]: + fields: List[DataRowMetadataSchema], + ) -> Dict[ + str, Union[DataRowMetadataSchema, Dict[str, DataRowMetadataSchema]] + ]: index = {} for f in fields: if f.options: @@ -239,7 +263,7 @@ def _make_name_index( @staticmethod def _make_normalized_name_index( - fields: List[DataRowMetadataSchema] + fields: List[DataRowMetadataSchema], ) -> Dict[str, DataRowMetadataSchema]: index = {} for f in fields: @@ -248,7 +272,7 @@ def _make_normalized_name_index( @staticmethod def _make_id_index( - fields: List[DataRowMetadataSchema] + fields: List[DataRowMetadataSchema], ) -> Dict[SchemaId, DataRowMetadataSchema]: index = {} for f in fields: @@ -287,29 +311,26 @@ def _parse_ontology(raw_ontology) -> List[DataRowMetadataSchema]: for option in schema["options"]: option["uid"] = option["id"] options.append( - DataRowMetadataSchema(**{ - **option, - **{ - "parent": schema["uid"] - } - })) + DataRowMetadataSchema( + **{**option, **{"parent": schema["uid"]}} + ) + ) schema["options"] = options fields.append(DataRowMetadataSchema(**schema)) return fields def refresh_ontology(self): - """ Update the `DataRowMetadataOntology` instance with the latest - metadata ontology schemas + """Update the `DataRowMetadataOntology` instance with the latest + metadata ontology schemas """ self._raw_ontology = self._get_ontology() self._build_ontology() - def create_schema(self, - name: str, - kind: DataRowMetadataKind, - options: List[str] = None) -> DataRowMetadataSchema: - """ Create metadata schema + def create_schema( + self, name: str, kind: DataRowMetadataKind, options: List[str] = None + ) -> DataRowMetadataSchema: + """Create metadata schema >>> mdo.create_schema(name, kind, options) @@ -327,8 +348,9 @@ def create_schema(self, if not isinstance(kind, DataRowMetadataKind): raise ValueError(f"kind '{kind}' must be a `DataRowMetadataKind`") - upsert_schema = _UpsertCustomMetadataSchemaInput(name=name, - kind=kind.value) + upsert_schema = _UpsertCustomMetadataSchemaInput( + name=name, kind=kind.value + ) if options: if kind != DataRowMetadataKind.enum: raise ValueError( @@ -336,7 +358,8 @@ def create_schema(self, ) upsert_enum_options = [ _UpsertCustomMetadataSchemaEnumOptionInput( - name=o, kind=DataRowMetadataKind.option.value) + name=o, kind=DataRowMetadataKind.option.value + ) for o in options ] upsert_schema.options = upsert_enum_options @@ -344,7 +367,7 @@ def create_schema(self, return self._upsert_schema(upsert_schema) def update_schema(self, name: str, new_name: str) -> DataRowMetadataSchema: - """ Update metadata schema + """Update metadata schema >>> mdo.update_schema(name, new_name) @@ -359,24 +382,24 @@ def update_schema(self, name: str, new_name: str) -> DataRowMetadataSchema: KeyError: When provided name is not a valid custom metadata """ schema = self._validate_custom_schema_by_name(name) - upsert_schema = _UpsertCustomMetadataSchemaInput(id=schema.uid, - name=new_name, - kind=schema.kind.value) + upsert_schema = _UpsertCustomMetadataSchemaInput( + id=schema.uid, name=new_name, 
kind=schema.kind.value + ) if schema.options: upsert_enum_options = [ _UpsertCustomMetadataSchemaEnumOptionInput( - id=o.uid, - name=o.name, - kind=DataRowMetadataKind.option.value) + id=o.uid, name=o.name, kind=DataRowMetadataKind.option.value + ) for o in schema.options ] upsert_schema.options = upsert_enum_options return self._upsert_schema(upsert_schema) - def update_enum_option(self, name: str, option: str, - new_option: str) -> DataRowMetadataSchema: - """ Update Enum metadata schema option + def update_enum_option( + self, name: str, option: str, new_option: str + ) -> DataRowMetadataSchema: + """Update Enum metadata schema option >>> mdo.update_enum_option(name, option, new_option) @@ -402,13 +425,14 @@ def update_enum_option(self, name: str, option: str, raise ValueError( f"Enum option '{option}' is not a valid option for Enum '{name}', valid options are: {valid_options}" ) - upsert_schema = _UpsertCustomMetadataSchemaInput(id=schema.uid, - name=schema.name, - kind=schema.kind.value) + upsert_schema = _UpsertCustomMetadataSchemaInput( + id=schema.uid, name=schema.name, kind=schema.kind.value + ) upsert_enum_options = [] for o in schema.options: enum_option = _UpsertCustomMetadataSchemaEnumOptionInput( - id=o.uid, name=o.name, kind=o.kind.value) + id=o.uid, name=o.name, kind=o.kind.value + ) if enum_option.name == option: enum_option.name = new_option upsert_enum_options.append(enum_option) @@ -417,7 +441,7 @@ def update_enum_option(self, name: str, option: str, return self._upsert_schema(upsert_schema) def delete_schema(self, name: str) -> bool: - """ Delete metadata schema + """Delete metadata schema >>> mdo.delete_schema(name) @@ -436,18 +460,17 @@ def delete_schema(self, name: str) -> bool: success } }""" - res = self._client.execute(query, {'where': { - 'id': schema.uid - }})['deleteCustomMetadataSchema'] + res = self._client.execute(query, {"where": {"id": schema.uid}})[ + "deleteCustomMetadataSchema" + ] self.refresh_ontology() - return res['success'] + return res["success"] def parse_metadata( - self, unparsed: List[Dict[str, - List[Union[str, - Dict]]]]) -> List[DataRowMetadata]: - """ Parse metadata responses + self, unparsed: List[Dict[str, List[Union[str, Dict]]]] + ) -> List[DataRowMetadata]: + """Parse metadata responses >>> mdo.parse_metadata([metadata]) @@ -466,15 +489,18 @@ def parse_metadata( if "fields" in dr: fields = self.parse_metadata_fields(dr["fields"]) parsed.append( - DataRowMetadata(data_row_id=dr["dataRowId"], - global_key=dr["globalKey"], - fields=fields)) + DataRowMetadata( + data_row_id=dr["dataRowId"], + global_key=dr["globalKey"], + fields=fields, + ) + ) return parsed def parse_metadata_fields( - self, unparsed: List[Dict[str, - Dict]]) -> List[DataRowMetadataField]: - """ Parse metadata fields as list of `DataRowMetadataField` + self, unparsed: List[Dict[str, Dict]] + ) -> List[DataRowMetadataField]: + """Parse metadata fields as list of `DataRowMetadataField` >>> mdo.parse_metadata_fields([metadata_fields]) @@ -494,31 +520,35 @@ def parse_metadata_fields( self.refresh_ontology() if f["schemaId"] not in self.fields_by_id: raise ValueError( - f"Schema Id `{f['schemaId']}` not found in ontology") + f"Schema Id `{f['schemaId']}` not found in ontology" + ) schema = self.fields_by_id[f["schemaId"]] if schema.kind == DataRowMetadataKind.enum: continue elif schema.kind == DataRowMetadataKind.option: - field = DataRowMetadataField(schema_id=schema.parent, - value=schema.uid) + field = DataRowMetadataField( + schema_id=schema.parent, value=schema.uid + ) 
elif schema.kind == DataRowMetadataKind.datetime: - field = DataRowMetadataField(schema_id=schema.uid, - value=format_iso_from_string( - f["value"])) + field = DataRowMetadataField( + schema_id=schema.uid, + value=format_iso_from_string(f["value"]), + ) else: - field = DataRowMetadataField(schema_id=schema.uid, - value=f["value"]) + field = DataRowMetadataField( + schema_id=schema.uid, value=f["value"] + ) field.name = schema.name parsed.append(field) return parsed def bulk_upsert( - self, metadata: List[DataRowMetadata] + self, metadata: List[DataRowMetadata] ) -> List[DataRowMetadataBatchResponse]: """Upsert metadata to a list of data rows - + You may specify data row by either data_row_id or global_key >>> metadata = DataRowMetadata( @@ -542,7 +572,7 @@ def bulk_upsert( raise ValueError("Empty list passed") def _batch_upsert( - upserts: List[_UpsertBatchDataRowMetadata] + upserts: List[_UpsertBatchDataRowMetadata], ) -> List[DataRowMetadataBatchResponse]: query = """mutation UpsertDataRowMetadataBetaPyApi($metadata: [DataRowCustomMetadataBatchUpsertInput!]!) { upsertDataRowCustomMetadata(data: $metadata){ @@ -555,14 +585,17 @@ def _batch_upsert( } } }""" - res = self._client.execute( - query, {"metadata": upserts})['upsertDataRowCustomMetadata'] + res = self._client.execute(query, {"metadata": upserts})[ + "upsertDataRowCustomMetadata" + ] return [ - DataRowMetadataBatchResponse(global_key=r['globalKey'], - data_row_id=r['dataRowId'], - error=r['error'], - fields=self.parse_metadata( - [r])[0].fields) for r in res + DataRowMetadataBatchResponse( + global_key=r["globalKey"], + data_row_id=r["dataRowId"], + error=r["error"], + fields=self.parse_metadata([r])[0].fields, + ) + for r in res ] items = [] @@ -574,14 +607,18 @@ def _batch_upsert( fields=list( chain.from_iterable( self._parse_upsert(f, m.data_row_id) - for f in m.fields))).model_dump(by_alias=True)) + for f in m.fields + ) + ), + ).model_dump(by_alias=True) + ) res = _batch_operations(_batch_upsert, items, self._batch_size) return res def bulk_delete( self, deletes: List[DeleteDataRowMetadata] ) -> List[DataRowMetadataBatchResponse]: - """ Delete metadata from a datarow by specifiying the fields you want to remove + """Delete metadata from a datarow by specifiying the fields you want to remove >>> delete = DeleteDataRowMetadata( >>> data_row_id=UniqueId("datarow-id"), @@ -616,7 +653,7 @@ def bulk_delete( Args: deletes: Data row and schema ids to delete - For data row, we support UniqueId, str, and GlobalKey. + For data row, we support UniqueId, str, and GlobalKey. If you pass a str, we will assume it is a UniqueId Do not pass a mix of data row ids and global keys in the same list @@ -633,9 +670,10 @@ def bulk_delete( for i, delete in enumerate(deletes): if isinstance(delete.data_row_id, str): passed_strings = True - deletes[i] = DeleteDataRowMetadata(data_row_id=UniqueId( - delete.data_row_id), - fields=delete.fields) + deletes[i] = DeleteDataRowMetadata( + data_row_id=UniqueId(delete.data_row_id), + fields=delete.fields, + ) elif isinstance(delete.data_row_id, UniqueId): continue elif isinstance(delete.data_row_id, GlobalKey): @@ -648,10 +686,11 @@ def bulk_delete( if passed_strings: warnings.warn( "Using string for data row id will be deprecated. Please use " - "UniqueId instead.") + "UniqueId instead." 
+ ) def _batch_delete( - deletes: List[_DeleteBatchDataRowMetadata] + deletes: List[_DeleteBatchDataRowMetadata], ) -> List[DataRowMetadataBatchResponse]: query = """mutation DeleteDataRowMetadataBetaPyApi($deletes: [DataRowIdentifierCustomMetadataBatchDeleteInput!]) { deleteDataRowCustomMetadata(dataRowIdentifiers: $deletes) { @@ -664,30 +703,32 @@ def _batch_delete( } } """ - res = self._client.execute( - query, {"deletes": deletes})['deleteDataRowCustomMetadata'] + res = self._client.execute(query, {"deletes": deletes})[ + "deleteDataRowCustomMetadata" + ] failures = [] for dr in res: - dr['fields'] = [f['schemaId'] for f in dr['fields']] + dr["fields"] = [f["schemaId"] for f in dr["fields"]] failures.append(DataRowMetadataBatchResponse(**dr)) return failures items = [self._validate_delete(m) for m in deletes] - return _batch_operations(_batch_delete, - items, - batch_size=self._batch_size) + return _batch_operations( + _batch_delete, items, batch_size=self._batch_size + ) @overload def bulk_export(self, data_row_ids: List[str]) -> List[DataRowMetadata]: pass @overload - def bulk_export(self, - data_row_ids: DataRowIdentifiers) -> List[DataRowMetadata]: + def bulk_export( + self, data_row_ids: DataRowIdentifiers + ) -> List[DataRowMetadata]: pass def bulk_export(self, data_row_ids) -> List[DataRowMetadata]: - """ Exports metadata for a list of data rows + """Exports metadata for a list of data rows >>> mdo.bulk_export([data_row.uid for data_row in data_rows]) @@ -704,15 +745,20 @@ def bulk_export(self, data_row_ids) -> List[DataRowMetadata]: if not len(data_row_ids): raise ValueError("Empty list passed") - if isinstance(data_row_ids, - list) and len(data_row_ids) > 0 and isinstance( - data_row_ids[0], str): + if ( + isinstance(data_row_ids, list) + and len(data_row_ids) > 0 + and isinstance(data_row_ids[0], str) + ): data_row_ids = UniqueIds(data_row_ids) - warnings.warn("Using data row ids will be deprecated. Please use " - "UniqueIds or GlobalKeys instead.") + warnings.warn( + "Using data row ids will be deprecated. Please use " + "UniqueIds or GlobalKeys instead." + ) def _bulk_export( - _data_row_ids: DataRowIdentifiers) -> List[DataRowMetadata]: + _data_row_ids: DataRowIdentifiers, + ) -> List[DataRowMetadata]: query = """query dataRowCustomMetadataPyApi($dataRowIdentifiers: DataRowCustomMetadataDataRowIdentifiersInput) { dataRowCustomMetadata(where: {dataRowIdentifiers : $dataRowIdentifiers}) { dataRowId @@ -726,19 +772,22 @@ def _bulk_export( """ return self.parse_metadata( self._client.execute( - query, { + query, + { "dataRowIdentifiers": { "ids": [id for id in _data_row_ids], - "idType": _data_row_ids.id_type + "idType": _data_row_ids.id_type, } - })['dataRowCustomMetadata']) + }, + )["dataRowCustomMetadata"] + ) - return _batch_operations(_bulk_export, - data_row_ids, - batch_size=self._batch_size) + return _batch_operations( + _bulk_export, data_row_ids, batch_size=self._batch_size + ) def parse_upsert_metadata(self, metadata_fields) -> List[Dict[str, Any]]: - """ Converts either `DataRowMetadataField` or a dictionary representation + """Converts either `DataRowMetadataField` or a dictionary representation of `DataRowMetadataField` into a validated, flattened dictionary of metadata fields that are used to create data row metadata. 
Used internally in `Dataset.create_data_rows()` @@ -758,14 +807,18 @@ def _convert_metadata_field(metadata_field): raise ValueError( f"Custom metadata field '{metadata_field}' must have a 'value' key" ) - if not "schema_id" in metadata_field and not "name" in metadata_field: + if ( + not "schema_id" in metadata_field + and not "name" in metadata_field + ): raise ValueError( f"Custom metadata field '{metadata_field}' must have either 'schema_id' or 'name' key" ) return DataRowMetadataField( schema_id=metadata_field.get("schema_id"), name=metadata_field.get("name"), - value=metadata_field["value"]) + value=metadata_field["value"], + ) else: raise ValueError( f"Metadata field '{metadata_field}' is neither 'DataRowMetadataField' type or a dictionary" @@ -774,7 +827,8 @@ def _convert_metadata_field(metadata_field): # Convert all metadata fields to DataRowMetadataField type metadata_fields = [_convert_metadata_field(m) for m in metadata_fields] parsed_metadata = list( - chain.from_iterable(self._parse_upsert(m) for m in metadata_fields)) + chain.from_iterable(self._parse_upsert(m) for m in metadata_fields) + ) return [m.model_dump(by_alias=True) for m in parsed_metadata] def _upsert_schema( @@ -793,8 +847,8 @@ def _upsert_schema( } }""" res = self._client.execute( - query, {"data": upsert_schema.model_dump(exclude_none=True) - })['upsertCustomMetadataSchema'] + query, {"data": upsert_schema.model_dump(exclude_none=True)} + )["upsertCustomMetadataSchema"] self.refresh_ontology() return _parse_metadata_schema(res) @@ -822,9 +876,7 @@ def _load_schema_id_by_name(self, metadatum: DataRowMetadataField): self._load_option_by_name(metadatum) def _parse_upsert( - self, - metadatum: DataRowMetadataField, - data_row_id: Optional[str] = None + self, metadatum: DataRowMetadataField, data_row_id: Optional[str] = None ) -> List[_UpsertDataRowMetadataInput]: """Format for metadata upserts to GQL""" @@ -835,7 +887,8 @@ def _parse_upsert( self.refresh_ontology() if metadatum.schema_id not in self.fields_by_id: raise ValueError( - f"Schema Id `{metadatum.schema_id}` not found in ontology") + f"Schema Id `{metadatum.schema_id}` not found in ontology" + ) schema = self.fields_by_id[metadatum.schema_id] try: @@ -851,7 +904,8 @@ def _parse_upsert( parsed = _validate_enum_parse(schema, metadatum) elif schema.kind == DataRowMetadataKind.option: raise ValueError( - "An Option id should not be set as the Schema id") + "An Option id should not be set as the Schema id" + ) else: raise ValueError(f"Unknown type: {schema}") except ValueError as e: @@ -872,7 +926,8 @@ def _validate_delete(self, delete: DeleteDataRowMetadata): self.refresh_ontology() if schema_id not in self.fields_by_id: raise ValueError( - f"Schema Id `{schema_id}` not found in ontology") + f"Schema Id `{schema_id}` not found in ontology" + ) schema = self.fields_by_id[schema_id] # handle users specifying enums by adding all option enums @@ -883,10 +938,12 @@ def _validate_delete(self, delete: DeleteDataRowMetadata): return _DeleteBatchDataRowMetadata( data_row_identifier=delete.data_row_id, - schema_ids=list(delete.fields)).model_dump(by_alias=True) + schema_ids=list(delete.fields), + ).model_dump(by_alias=True) - def _validate_custom_schema_by_name(self, - name: str) -> DataRowMetadataSchema: + def _validate_custom_schema_by_name( + self, name: str + ) -> DataRowMetadataSchema: if name not in self.custom_by_name_normalized: # Fetch latest metadata ontology if metadata can't be found self.refresh_ontology() @@ -899,7 +956,7 @@ def 
_validate_custom_schema_by_name(self, def _batch_items(iterable: List[Any], size: int) -> Generator[Any, None, None]: l = len(iterable) for ndx in range(0, l, size): - yield iterable[ndx:min(ndx + size, l)] + yield iterable[ndx : min(ndx + size, l)] def _batch_operations( @@ -915,9 +972,8 @@ def _batch_operations( def _validate_parse_embedding( - field: DataRowMetadataField + field: DataRowMetadataField, ) -> List[Dict[str, Union[SchemaId, Embedding]]]: - if isinstance(field.value, list): if not (Embedding.min_items <= len(field.value) <= Embedding.max_items): raise ValueError( @@ -928,19 +984,21 @@ def _validate_parse_embedding( field.value = [float(x) for x in field.value] else: raise ValueError( - f"Expected a list for embedding. Found {type(field.value)}") + f"Expected a list for embedding. Found {type(field.value)}" + ) return [field.model_dump(by_alias=True)] def _validate_parse_number( - field: DataRowMetadataField + field: DataRowMetadataField, ) -> List[Dict[str, Union[SchemaId, str, float, int]]]: field.value = float(field.value) return [field.model_dump(by_alias=True)] def _validate_parse_datetime( - field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]: + field: DataRowMetadataField, +) -> List[Dict[str, Union[SchemaId, str]]]: if isinstance(field.value, str): field.value = format_iso_from_string(field.value) elif not isinstance(field.value, datetime): @@ -948,57 +1006,58 @@ def _validate_parse_datetime( f"Value for datetime fields must be either a string or datetime object. Found {type(field.value)}" ) - return [{ - "schemaId": field.schema_id, - "value": format_iso_datetime(field.value) - }] + return [ + {"schemaId": field.schema_id, "value": format_iso_datetime(field.value)} + ] def _validate_parse_text( - field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]: + field: DataRowMetadataField, +) -> List[Dict[str, Union[SchemaId, str]]]: if not isinstance(field.value, str): raise ValueError( f"Expected a string type for the text field. Found {type(field.value)}" ) if len(field.value) > String.metadata[0].max_length: raise ValueError( - f"String fields cannot exceed {String.metadata.max_length} characters.") + f"String fields cannot exceed {String.metadata.max_length} characters." 
+ ) return [field.model_dump(by_alias=True)] def _validate_enum_parse( - schema: DataRowMetadataSchema, - field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, dict]]]: + schema: DataRowMetadataSchema, field: DataRowMetadataField +) -> List[Dict[str, Union[SchemaId, dict]]]: if schema.options: if field.value not in {o.uid for o in schema.options}: raise ValueError( - f"Option `{field.value}` not found for {field.schema_id}") + f"Option `{field.value}` not found for {field.schema_id}" + ) else: raise ValueError("Incorrectly specified enum schema") - return [{ - "schemaId": field.schema_id, - "value": {} - }, { - "schemaId": field.value, - "value": {} - }] + return [ + {"schemaId": field.schema_id, "value": {}}, + {"schemaId": field.value, "value": {}}, + ] def _parse_metadata_schema( - unparsed: Dict[str, Union[str, List]]) -> DataRowMetadataSchema: - uid = unparsed['id'] - name = unparsed['name'] - kind = DataRowMetadataKind(unparsed['kind']) + unparsed: Dict[str, Union[str, List]], +) -> DataRowMetadataSchema: + uid = unparsed["id"] + name = unparsed["name"] + kind = DataRowMetadataKind(unparsed["kind"]) options = [ - DataRowMetadataSchema(uid=o['id'], - name=o['name'], - reserved=False, - kind=DataRowMetadataKind.option, - parent=uid) for o in unparsed['options'] + DataRowMetadataSchema( + uid=o["id"], + name=o["name"], + reserved=False, + kind=DataRowMetadataKind.option, + parent=uid, + ) + for o in unparsed["options"] ] - return DataRowMetadataSchema(uid=uid, - name=name, - reserved=False, - kind=kind, - options=options or None) + return DataRowMetadataSchema( + uid=uid, name=name, reserved=False, kind=kind, options=options or None + ) diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py index eaa37c5b7..17a3afc3d 100644 --- a/libs/labelbox/src/labelbox/schema/dataset.py +++ b/libs/labelbox/src/labelbox/schema/dataset.py @@ -15,7 +15,12 @@ from io import StringIO import requests -from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, ResourceCreationError +from labelbox.exceptions import ( + InvalidQueryError, + LabelboxError, + ResourceNotFoundError, + ResourceCreationError, +) from labelbox.orm.comparison import Comparison from labelbox.orm.db_object import DbObject, Updateable, Deletable, experimental from labelbox.orm.model import Entity, Field, Relationship @@ -25,25 +30,34 @@ from labelbox.schema.data_row import DataRow from labelbox.schema.embedding import EmbeddingVector from labelbox.schema.export_filters import DatasetExportFilters, build_filters -from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params +from labelbox.schema.export_params import ( + CatalogExportParams, + validate_catalog_export_params, +) from labelbox.schema.export_task import ExportTask from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.schema.task import Task, DataUpsertTask from labelbox.schema.user import User from labelbox.schema.iam_integration import IAMIntegration -from labelbox.schema.internal.data_row_upsert_item import (DataRowItemBase, - DataRowUpsertItem, - DataRowCreateItem) +from labelbox.schema.internal.data_row_upsert_item import ( + DataRowItemBase, + DataRowUpsertItem, + DataRowCreateItem, +) import labelbox.schema.internal.data_row_uploader as data_row_uploader -from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator +from labelbox.schema.internal.descriptor_file_creator import ( + DescriptorFileCreator, +) 
from labelbox.schema.internal.datarow_upload_constants import ( - FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE_BYTES) + FILE_UPLOAD_THREAD_COUNT, + UPSERT_CHUNK_SIZE_BYTES, +) logger = logging.getLogger(__name__) class Dataset(DbObject, Updateable, Deletable): - """ A Dataset is a collection of DataRows. + """A Dataset is a collection of DataRows. Attributes: name (str) @@ -65,8 +79,9 @@ class Dataset(DbObject, Updateable, Deletable): # Relationships created_by = Relationship.ToOne("User", False, "created_by") organization = Relationship.ToOne("Organization", False) - iam_integration = Relationship.ToOne("IAMIntegration", False, - "iam_integration", "signer") + iam_integration = Relationship.ToOne( + "IAMIntegration", False, "iam_integration", "signer" + ) def data_rows( self, @@ -90,8 +105,11 @@ def data_rows( """ page_size = 500 # hardcode to avoid overloading the server - where_param = query.where_as_dict(Entity.DataRow, - where) if where is not None else None + where_param = ( + query.where_as_dict(Entity.DataRow, where) + if where is not None + else None + ) template = Template( """query DatasetDataRowsPyApi($$id: ID!, $$from: ID, $$first: Int, $$where: DatasetDataRowWhereInput) { @@ -101,28 +119,30 @@ def data_rows( pageInfo { hasNextPage startCursor } } } - """) + """ + ) query_str = template.substitute( - datarow_selections=query.results_query_part(Entity.DataRow)) + datarow_selections=query.results_query_part(Entity.DataRow) + ) params = { - 'id': self.uid, - 'from': from_cursor, - 'first': page_size, - 'where': where_param, + "id": self.uid, + "from": from_cursor, + "first": page_size, + "where": where_param, } return PaginatedCollection( client=self.client, query=query_str, params=params, - dereferencing=['datasetDataRows', 'nodes'], + dereferencing=["datasetDataRows", "nodes"], obj_class=Entity.DataRow, - cursor_path=['datasetDataRows', 'pageInfo', 'startCursor'], + cursor_path=["datasetDataRows", "pageInfo", "startCursor"], ) def create_data_row(self, items=None, **kwargs) -> "DataRow": - """ Creates a single DataRow belonging to this dataset. + """Creates a single DataRow belonging to this dataset. >>> dataset.create_data_row(row_data="http://my_site.com/photos/img_01.jpg") Args: @@ -148,7 +168,8 @@ def create_data_row(self, items=None, **kwargs) -> "DataRow": file_upload_thread_count = 1 completed_task = self._create_data_rows_sync( - [args], file_upload_thread_count=file_upload_thread_count) + [args], file_upload_thread_count=file_upload_thread_count + ) res = completed_task.result if res is None or len(res) == 0: @@ -156,13 +177,12 @@ def create_data_row(self, items=None, **kwargs) -> "DataRow": f"Data row upload did not complete, task status {completed_task.status} task id {completed_task.uid}" ) - return self.client.get_data_row(res[0]['id']) + return self.client.get_data_row(res[0]["id"]) def create_data_rows_sync( - self, - items, - file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT) -> None: - """ Synchronously bulk upload data rows. + self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT + ) -> None: + """Synchronously bulk upload data rows. Use this instead of `Dataset.create_data_rows` for smaller batches of data rows that need to be uploaded quickly. Cannot use this for uploads containing more than 1000 data rows. @@ -184,17 +204,18 @@ def create_data_rows_sync( """ warnings.warn( "This method is deprecated and will be " - "removed in a future release. Please use create_data_rows instead.") + "removed in a future release. 
Please use create_data_rows instead." + ) self._create_data_rows_sync( - items, file_upload_thread_count=file_upload_thread_count) + items, file_upload_thread_count=file_upload_thread_count + ) return None # Return None if no exception is raised - def _create_data_rows_sync(self, - items, - file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT - ) -> "DataUpsertTask": + def _create_data_rows_sync( + self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT + ) -> "DataUpsertTask": max_data_rows_supported = 1000 if len(items) > max_data_rows_supported: raise ValueError( @@ -203,15 +224,18 @@ def _create_data_rows_sync(self, ) if file_upload_thread_count < 1: raise ValueError( - "file_upload_thread_count must be a positive integer") + "file_upload_thread_count must be a positive integer" + ) - task: DataUpsertTask = self.create_data_rows(items, - file_upload_thread_count) + task: DataUpsertTask = self.create_data_rows( + items, file_upload_thread_count + ) task.wait_till_done() if task.has_errors(): raise ResourceCreationError( - f"Data row upload errors: {task.errors}", cause=task.uid) + f"Data row upload errors: {task.errors}", cause=task.uid + ) if task.status != "COMPLETE": raise ResourceCreationError( f"Data row upload did not complete, task status {task.status} task id {task.uid}" @@ -219,11 +243,10 @@ def _create_data_rows_sync(self, return task - def create_data_rows(self, - items, - file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT - ) -> "DataUpsertTask": - """ Asynchronously bulk upload data rows + def create_data_rows( + self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT + ) -> "DataUpsertTask": + """Asynchronously bulk upload data rows Use this instead of `Dataset.create_data_rows_sync` uploads for batches that contain more than 1000 data rows. @@ -249,7 +272,8 @@ def create_data_rows(self, if file_upload_thread_count < 1: raise ValueError( - "file_upload_thread_count must be a positive integer") + "file_upload_thread_count must be a positive integer" + ) # Usage example upload_items = self._separate_and_process_items(items) @@ -265,14 +289,15 @@ def _separate_and_process_items(self, items): return dict_items + dict_string_items def _build_from_local_paths( - self, - items: List[str], - file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT) -> List[dict]: + self, + items: List[str], + file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT, + ) -> List[dict]: uploaded_items = [] def upload_file(item): item_url = self.client.upload_file(item) - return {'row_data': item_url, 'external_id': item} + return {"row_data": item_url, "external_id": item} with ThreadPoolExecutor(file_upload_thread_count) as executor: futures = [ @@ -285,10 +310,10 @@ def upload_file(item): return uploaded_items - def data_rows_for_external_id(self, - external_id, - limit=10) -> List["DataRow"]: - """ Convenience method for getting a multiple `DataRow` belonging to this + def data_rows_for_external_id( + self, external_id, limit=10 + ) -> List["DataRow"]: + """Convenience method for getting a multiple `DataRow` belonging to this `Dataset` that has the given `external_id`. Args: @@ -315,7 +340,7 @@ def data_rows_for_external_id(self, return at_most_data_rows def data_row_for_external_id(self, external_id) -> "DataRow": - """ Convenience method for getting a single `DataRow` belonging to this + """Convenience method for getting a single `DataRow` belonging to this `Dataset` that has the given `external_id`. 
Args: @@ -329,18 +354,20 @@ def data_row_for_external_id(self, external_id) -> "DataRow": in this `DataSet` with the given external ID, or if there are multiple `DataRows` for it. """ - data_rows = self.data_rows_for_external_id(external_id=external_id, - limit=2) + data_rows = self.data_rows_for_external_id( + external_id=external_id, limit=2 + ) if len(data_rows) > 1: logger.warning( f"More than one data_row has the provided external_id : `%s`. Use function data_rows_for_external_id to fetch all", - external_id) + external_id, + ) return data_rows[0] - def export_data_rows(self, - timeout_seconds=120, - include_metadata: bool = False) -> Generator: - """ Returns a generator that produces all data rows that are currently + def export_data_rows( + self, timeout_seconds=120, include_metadata: bool = False + ) -> Generator: + """Returns a generator that produces all data rows that are currently attached to this dataset. Note: For efficiency, the data are cached for 30 minutes. Newly created data rows will not appear @@ -356,7 +383,8 @@ def export_data_rows(self, """ warnings.warn( "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) + DeprecationWarning, + ) id_param = "datasetId" metadata_param = "includeMetadataInput" query_str = """mutation GetDatasetDataRowsExportUrlPyApi($%s: ID!, $%s: Boolean!) @@ -364,10 +392,10 @@ def export_data_rows(self, """ % (id_param, metadata_param, id_param, metadata_param) sleep_time = 2 while True: - res = self.client.execute(query_str, { - id_param: self.uid, - metadata_param: include_metadata - }) + res = self.client.execute( + query_str, + {id_param: self.uid, metadata_param: include_metadata}, + ) res = res["exportDatasetDataRows"] if res["status"] == "COMPLETE": download_url = res["downloadUrl"] @@ -375,7 +403,8 @@ def export_data_rows(self, response.raise_for_status() reader = parser.reader(StringIO(response.text)) return ( - Entity.DataRow(self.client, result) for result in reader) + Entity.DataRow(self.client, result) for result in reader + ) elif res["status"] == "FAILED": raise LabelboxError("Data row export failed.") @@ -385,8 +414,9 @@ def export_data_rows(self, f"Unable to export data rows within {timeout_seconds} seconds." 
) - logger.debug("Dataset '%s' data row export, waiting for server...", - self.uid) + logger.debug( + "Dataset '%s' data row export, waiting for server...", self.uid + ) time.sleep(sleep_time) def export( @@ -439,7 +469,7 @@ def export_v2( >>> task.result """ task, is_streamable = self._export(task_name, filters, params) - if (is_streamable): + if is_streamable: return ExportTask(task, True) return task @@ -450,36 +480,41 @@ def _export( params: Optional[CatalogExportParams] = None, streamable: bool = False, ) -> Tuple[Task, bool]: - _params = params or CatalogExportParams({ - "attachments": False, - "embeddings": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "model_run_ids": None, - "project_ids": None, - "interpolated_frames": False, - "all_projects": False, - "all_model_runs": False, - }) + _params = params or CatalogExportParams( + { + "attachments": False, + "embeddings": False, + "metadata_fields": False, + "data_row_details": False, + "project_details": False, + "performance_details": False, + "label_details": False, + "media_type_override": None, + "model_run_ids": None, + "project_ids": None, + "interpolated_frames": False, + "all_projects": False, + "all_model_runs": False, + } + ) validate_catalog_export_params(_params) - _filters = filters or DatasetExportFilters({ - "last_activity_at": None, - "label_created_at": None, - "data_row_ids": None, - "global_keys": None, - }) + _filters = filters or DatasetExportFilters( + { + "last_activity_at": None, + "label_created_at": None, + "data_row_ids": None, + "global_keys": None, + } + ) mutation_name = "exportDataRowsInCatalog" create_task_query_str = ( f"mutation {mutation_name}PyApi" f"($input: ExportDataRowsInCatalogInput!)" - f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}") - media_type_override = _params.get('media_type_override', None) + f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}" + ) + media_type_override = _params.get("media_type_override", None) if task_name is None: task_name = f"Export v2: dataset - {self.name}" @@ -494,61 +529,53 @@ def _export( }, "isStreamableReady": True, "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeEmbeddings": - _params.get('embeddings', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), - "includePredictions": - _params.get('predictions', False), - "projectIds": - _params.get('project_ids', None), - "modelRunIds": - _params.get('model_run_ids', None), - "allProjects": - _params.get('all_projects', False), - "allModelRuns": - _params.get('all_model_runs', False), + "mediaTypeOverride": media_type_override.value + if media_type_override is not None + else None, + "includeAttachments": _params.get("attachments", False), + "includeEmbeddings": _params.get("embeddings", False), + "includeMetadata": _params.get("metadata_fields", False), + "includeDataRowDetails": _params.get( + "data_row_details", False + ), + "includeProjectDetails": 
_params.get( + "project_details", False + ), + "includePerformanceDetails": _params.get( + "performance_details", False + ), + "includeLabelDetails": _params.get("label_details", False), + "includeInterpolatedFrames": _params.get( + "interpolated_frames", False + ), + "includePredictions": _params.get("predictions", False), + "projectIds": _params.get("project_ids", None), + "modelRunIds": _params.get("model_run_ids", None), + "allProjects": _params.get("all_projects", False), + "allModelRuns": _params.get("all_model_runs", False), }, "streamable": streamable, } } search_query = build_filters(self.client, _filters) - search_query.append({ - "ids": [self.uid], - "operator": "is", - "type": "dataset" - }) + search_query.append( + {"ids": [self.uid], "operator": "is", "type": "dataset"} + ) query_params["input"]["filters"]["searchQuery"]["query"] = search_query - res = self.client.execute(create_task_query_str, - query_params, - error_log_key="errors") + res = self.client.execute( + create_task_query_str, query_params, error_log_key="errors" + ) res = res[mutation_name] task_id = res["taskId"] is_streamable = res["isStreamable"] return Task.get_task(self.client, task_id), is_streamable - def upsert_data_rows(self, - items, - file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT - ) -> "DataUpsertTask": + def upsert_data_rows( + self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT + ) -> "DataUpsertTask": """ Upserts data rows in this dataset. When "key" is provided, and it references an existing data row, an update will be performed. When "key" is not provided a new data row will be created. @@ -585,19 +612,19 @@ def upsert_data_rows(self, def _exec_upsert_data_rows( self, specs: List[DataRowItemBase], - file_upload_thread_count: int = FILE_UPLOAD_THREAD_COUNT + file_upload_thread_count: int = FILE_UPLOAD_THREAD_COUNT, ) -> "DataUpsertTask": - manifest = data_row_uploader.upload_in_chunks( client=self.client, specs=specs, file_upload_thread_count=file_upload_thread_count, - max_chunk_size_bytes=UPSERT_CHUNK_SIZE_BYTES) + max_chunk_size_bytes=UPSERT_CHUNK_SIZE_BYTES, + ) data = json.dumps(manifest.model_dump()).encode("utf-8") - manifest_uri = self.client.upload_data(data, - content_type="application/json", - filename="manifest.json") + manifest_uri = self.client.upload_data( + data, content_type="application/json", filename="manifest.json" + ) query_str = """ mutation UpsertDataRowsPyApi($manifestUri: String!) { @@ -614,44 +641,47 @@ def _exec_upsert_data_rows( return task def add_iam_integration( - self, iam_integration: Union[str, - IAMIntegration]) -> IAMIntegration: - """ - Sets the IAM integration for the dataset. IAM integration is used to sign URLs for data row assets. - - Args: - iam_integration (Union[str, IAMIntegration]): IAM integration object or IAM integration id. - - Returns: - IAMIntegration: IAM integration object. - - Raises: - LabelboxError: If the IAM integration can't be set. + self, iam_integration: Union[str, IAMIntegration] + ) -> IAMIntegration: + """ + Sets the IAM integration for the dataset. IAM integration is used to sign URLs for data row assets. 
- Examples: - - >>> # Get all IAM integrations - >>> iam_integrations = client.get_organization().get_iam_integrations() - >>> - >>> # Get IAM integration id - >>> iam_integration_id = [integration.uid for integration - >>> in iam_integrations - >>> if integration.name == "My S3 integration"][0] - >>> - >>> # Set IAM integration for integration id - >>> dataset.set_iam_integration(iam_integration_id) - >>> - >>> # Get IAM integration object - >>> iam_integration = [integration.uid for integration - >>> in iam_integrations - >>> if integration.name == "My S3 integration"][0] - >>> - >>> # Set IAM integration for IAMIntegrtion object - >>> dataset.set_iam_integration(iam_integration) + Args: + iam_integration (Union[str, IAMIntegration]): IAM integration object or IAM integration id. + + Returns: + IAMIntegration: IAM integration object. + + Raises: + LabelboxError: If the IAM integration can't be set. + + Examples: + + >>> # Get all IAM integrations + >>> iam_integrations = client.get_organization().get_iam_integrations() + >>> + >>> # Get IAM integration id + >>> iam_integration_id = [integration.uid for integration + >>> in iam_integrations + >>> if integration.name == "My S3 integration"][0] + >>> + >>> # Set IAM integration for integration id + >>> dataset.set_iam_integration(iam_integration_id) + >>> + >>> # Get IAM integration object + >>> iam_integration = [integration.uid for integration + >>> in iam_integrations + >>> if integration.name == "My S3 integration"][0] + >>> + >>> # Set IAM integration for IAMIntegrtion object + >>> dataset.set_iam_integration(iam_integration) """ - iam_integration_id = iam_integration.uid if isinstance( - iam_integration, IAMIntegration) else iam_integration + iam_integration_id = ( + iam_integration.uid + if isinstance(iam_integration, IAMIntegration) + else iam_integration + ) query = """ mutation SetSignerForDatasetPyApi($signerId: ID!, $datasetId: ID!) 
{ @@ -667,29 +697,30 @@ def add_iam_integration( } """ - response = self.client.execute(query, { - "signerId": iam_integration_id, - "datasetId": self.uid - }) + response = self.client.execute( + query, {"signerId": iam_integration_id, "datasetId": self.uid} + ) if not response: - raise ResourceNotFoundError(IAMIntegration, { - "signerId": iam_integration_id, - "datasetId": self.uid - }) + raise ResourceNotFoundError( + IAMIntegration, + {"signerId": iam_integration_id, "datasetId": self.uid}, + ) try: - iam_integration_id = response.get("setSignerForDataset", - {}).get("signer", {})["id"] + iam_integration_id = response.get("setSignerForDataset", {}).get( + "signer", {} + )["id"] return [ - integration for integration in - self.client.get_organization().get_iam_integrations() + integration + for integration in self.client.get_organization().get_iam_integrations() if integration.uid == iam_integration_id ][0] except: raise LabelboxError( - f"Can't retrieve IAM integration {iam_integration_id}") + f"Can't retrieve IAM integration {iam_integration_id}" + ) def remove_iam_integration(self) -> None: """ diff --git a/libs/labelbox/src/labelbox/schema/embedding.py b/libs/labelbox/src/labelbox/schema/embedding.py index a67b82d38..dd5224c7e 100644 --- a/libs/labelbox/src/labelbox/schema/embedding.py +++ b/libs/labelbox/src/labelbox/schema/embedding.py @@ -13,6 +13,7 @@ class EmbeddingVector(BaseModel): vector (list): The raw vector values - the number of entries should match the Embedding's dimensions clusters (list): The cluster groupings """ + embedding_id: str vector: List[float] clusters: Optional[List[int]] = None @@ -37,6 +38,7 @@ class Embedding(BaseModel): dims (int): Refers to the size of the vector space in which words, phrases, or other entities are embedded custom (bool): Indicates whether the embedding is a Precomputed embedding or a Custom embedding """ + id: str name: str custom: bool @@ -54,10 +56,11 @@ def delete(self): """ self._client.delete_embedding(self.id) - def import_vectors_from_file(self, - path: str, - callback: Optional[Callable[[Dict[str, Any]], - None]] = None): + def import_vectors_from_file( + self, + path: str, + callback: Optional[Callable[[Dict[str, Any]], None]] = None, + ): """ Import vectors into a given embedding from an NDJSON file. An NDJSON file consists of newline delimited JSON. Each line of the file diff --git a/libs/labelbox/src/labelbox/schema/enums.py b/libs/labelbox/src/labelbox/schema/enums.py index c08e91bfa..6f8aebc58 100644 --- a/libs/labelbox/src/labelbox/schema/enums.py +++ b/libs/labelbox/src/labelbox/schema/enums.py @@ -2,7 +2,7 @@ class BulkImportRequestState(Enum): - """ State of the import job when importing annotations (RUNNING, FAILED, or FINISHED). + """State of the import job when importing annotations (RUNNING, FAILED, or FINISHED). If you are not usinig MEA continue using BulkImportRequest. AnnotationImports are in beta and will change soon. @@ -20,13 +20,14 @@ class BulkImportRequestState(Enum): * - FINISHED - Indicates the import job is no longer running. Check `BulkImportRequest.statuses` for more information """ + RUNNING = "RUNNING" FAILED = "FAILED" FINISHED = "FINISHED" class AnnotationImportState(Enum): - """ State of the import job when importing annotations (RUNNING, FAILED, or FINISHED). + """State of the import job when importing annotations (RUNNING, FAILED, or FINISHED). .. list-table:: :widths: 15 150 @@ -41,23 +42,25 @@ class AnnotationImportState(Enum): * - FINISHED - Indicates the import job is no longer running. 
Check `AnnotationImport.statuses` for more information """ + RUNNING = "RUNNING" FAILED = "FAILED" FINISHED = "FINISHED" class CollectionJobStatus(Enum): - """ Status of an asynchronous job over a collection. - - * - State - - Description - * - SUCCESS - - Indicates job has successfully processed entire collection of data - * - PARTIAL SUCCESS - - Indicates some data in the collection has succeeded and other data have failed - * - FAILURE - - Indicates job has failed to process entire collection of data + """Status of an asynchronous job over a collection. + + * - State + - Description + * - SUCCESS + - Indicates job has successfully processed entire collection of data + * - PARTIAL SUCCESS + - Indicates some data in the collection has succeeded and other data have failed + * - FAILURE + - Indicates job has failed to process entire collection of data """ + SUCCESS = "SUCCESS" PARTIAL_SUCCESS = "PARTIAL SUCCESS" - FAILURE = "FAILURE" \ No newline at end of file + FAILURE = "FAILURE" diff --git a/libs/labelbox/src/labelbox/schema/export_filters.py b/libs/labelbox/src/labelbox/schema/export_filters.py index aa97cbced..641adc011 100644 --- a/libs/labelbox/src/labelbox/schema/export_filters.py +++ b/libs/labelbox/src/labelbox/schema/export_filters.py @@ -3,6 +3,7 @@ from datetime import datetime, timezone from typing import Collection, Dict, Tuple, List, Optional from labelbox.typing_imports import Literal + if sys.version_info >= (3, 8): from typing import TypedDict else: @@ -47,8 +48,9 @@ class ProjectExportFilters(SharedExportFilters): Example: >>> ["clgo3lyax0000veeezdbu3ws4"] """ - workflow_status: Optional[Literal["ToLabel", "InReview", "InRework", - "Done"]] + workflow_status: Optional[ + Literal["ToLabel", "InReview", "InRework", "Done"] + ] """ Export data rows matching workflow status Example: >>> "InReview" @@ -68,7 +70,7 @@ class DatarowExportFilters(BaseExportFilters): def validate_datetime(datetime_str: str) -> bool: - """helper function to validate that datetime's format: "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss" + """helper function to validate that datetime's format: "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss" or ISO 8061 format "YYYY-MM-DDThh:mm:ss±hhmm" (Example: "2023-05-23T14:30:00+0530")""" if datetime_str: for fmt in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", ISO_8061_FORMAT): @@ -78,8 +80,7 @@ def validate_datetime(datetime_str: str) -> bool: except ValueError: pass raise ValueError(f"""Incorrect format for: {datetime_str}. 
- Format must be \"YYYY-MM-DD\" or \"YYYY-MM-DD hh:mm:ss\" or ISO 8061 format \"YYYY-MM-DDThh:mm:ss±hhmm\"""" - ) + Format must be \"YYYY-MM-DD\" or \"YYYY-MM-DD hh:mm:ss\" or ISO 8061 format \"YYYY-MM-DDThh:mm:ss±hhmm\"""") return True @@ -96,8 +97,10 @@ def convert_to_utc_if_iso8061(datetime_str: str, timezone_str: Optional[str]): def validate_one_of_data_row_ids_or_global_keys(filters): - if filters.get("data_row_ids") is not None and filters.get( - "global_keys") is not None: + if ( + filters.get("data_row_ids") is not None + and filters.get("global_keys") is not None + ): raise ValueError( "data_rows and global_keys cannot both be present in export filters" ) @@ -117,9 +120,11 @@ def _get_timezone() -> str: tz_res = client.execute(timezone_query_str) return tz_res["user"]["timezone"] or "UTC" - def _build_id_filters(ids: list, - type_name: str, - search_where_limit: int = SEARCH_LIMIT_PER_EXPORT_V2): + def _build_id_filters( + ids: list, + type_name: str, + search_where_limit: int = SEARCH_LIMIT_PER_EXPORT_V2, + ): if not isinstance(ids, list): raise ValueError(f"{type_name} filter expects a list.") if len(ids) == 0: @@ -136,85 +141,91 @@ def _build_id_filters(ids: list, if last_activity_at: timezone = _get_timezone() start, end = last_activity_at - if (start is not None and end is not None): + if start is not None and end is not None: [validate_datetime(date) for date in last_activity_at] start, timezone = convert_to_utc_if_iso8061(start, timezone) end, timezone = convert_to_utc_if_iso8061(end, timezone) - search_query.append({ - "type": "data_row_last_activity_at", - "value": { - "operator": "BETWEEN", - "timezone": timezone, + search_query.append( + { + "type": "data_row_last_activity_at", "value": { - "min": start, - "max": end - } + "operator": "BETWEEN", + "timezone": timezone, + "value": {"min": start, "max": end}, + }, } - }) - elif (start is not None): + ) + elif start is not None: validate_datetime(start) start, timezone = convert_to_utc_if_iso8061(start, timezone) - search_query.append({ - "type": "data_row_last_activity_at", - "value": { - "operator": "GREATER_THAN_OR_EQUAL", - "timezone": timezone, - "value": start + search_query.append( + { + "type": "data_row_last_activity_at", + "value": { + "operator": "GREATER_THAN_OR_EQUAL", + "timezone": timezone, + "value": start, + }, } - }) - elif (end is not None): + ) + elif end is not None: validate_datetime(end) end, timezone = convert_to_utc_if_iso8061(end, timezone) - search_query.append({ - "type": "data_row_last_activity_at", - "value": { - "operator": "LESS_THAN_OR_EQUAL", - "timezone": timezone, - "value": end + search_query.append( + { + "type": "data_row_last_activity_at", + "value": { + "operator": "LESS_THAN_OR_EQUAL", + "timezone": timezone, + "value": end, + }, } - }) + ) label_created_at = filters.get("label_created_at") if label_created_at: timezone = _get_timezone() start, end = label_created_at - if (start is not None and end is not None): + if start is not None and end is not None: [validate_datetime(date) for date in label_created_at] start, timezone = convert_to_utc_if_iso8061(start, timezone) end, timezone = convert_to_utc_if_iso8061(end, timezone) - search_query.append({ - "type": "labeled_at", - "value": { - "operator": "BETWEEN", - "timezone": timezone, + search_query.append( + { + "type": "labeled_at", "value": { - "min": start, - "max": end - } + "operator": "BETWEEN", + "timezone": timezone, + "value": {"min": start, "max": end}, + }, } - }) - elif (start is not None): + ) + elif start is not None: 
validate_datetime(start) start, timezone = convert_to_utc_if_iso8061(start, timezone) - search_query.append({ - "type": "labeled_at", - "value": { - "operator": "GREATER_THAN_OR_EQUAL", - "timezone": timezone, - "value": start + search_query.append( + { + "type": "labeled_at", + "value": { + "operator": "GREATER_THAN_OR_EQUAL", + "timezone": timezone, + "value": start, + }, } - }) - elif (end is not None): + ) + elif end is not None: validate_datetime(end) end, timezone = convert_to_utc_if_iso8061(end, timezone) - search_query.append({ - "type": "labeled_at", - "value": { - "operator": "LESS_THAN_OR_EQUAL", - "timezone": timezone, - "value": end + search_query.append( + { + "type": "labeled_at", + "value": { + "operator": "LESS_THAN_OR_EQUAL", + "timezone": timezone, + "value": end, + }, } - }) + ) data_row_ids = filters.get("data_row_ids") if data_row_ids is not None: @@ -240,9 +251,8 @@ def _build_id_filters(ids: list, if workflow_status == "ToLabel": search_query.append({"type": "task_queue_not_exist"}) else: - search_query.append({ - "type": 'task_queue_status', - "status": workflow_status - }) + search_query.append( + {"type": "task_queue_status", "status": workflow_status} + ) return search_query diff --git a/libs/labelbox/src/labelbox/schema/export_params.py b/libs/labelbox/src/labelbox/schema/export_params.py index 5229e2bfa..b15bc2828 100644 --- a/libs/labelbox/src/labelbox/schema/export_params.py +++ b/libs/labelbox/src/labelbox/schema/export_params.py @@ -5,6 +5,7 @@ EXPORT_LIMIT = 30 from labelbox.schema.media_type import MediaType + if sys.version_info >= (3, 8): from typing import TypedDict else: @@ -49,9 +50,11 @@ def _validate_array_length(array, max_length, array_name): def validate_catalog_export_params(params: CatalogExportParams): if "model_run_ids" in params and params["model_run_ids"] is not None: - _validate_array_length(params["model_run_ids"], EXPORT_LIMIT, - "model_run_ids") + _validate_array_length( + params["model_run_ids"], EXPORT_LIMIT, "model_run_ids" + ) if "project_ids" in params and params["project_ids"] is not None: - _validate_array_length(params["project_ids"], EXPORT_LIMIT, - "project_ids") + _validate_array_length( + params["project_ids"], EXPORT_LIMIT, "project_ids" + ) diff --git a/libs/labelbox/src/labelbox/schema/export_task.py b/libs/labelbox/src/labelbox/schema/export_task.py index 423e66ceb..a144f4c76 100644 --- a/libs/labelbox/src/labelbox/schema/export_task.py +++ b/libs/labelbox/src/labelbox/schema/export_task.py @@ -111,7 +111,7 @@ class JsonConverterOutput: class JsonConverter(Converter[JsonConverterOutput]): # pylint: disable=too-few-public-methods """Converts JSON data. - + Deprecated: This converter is deprecated and will be removed in a future release. 
""" @@ -133,16 +133,21 @@ def _find_json_object_offsets(self, data: str) -> List[Tuple[int, int]]: current_object_start = index # we need to account for scenarios where data lands in the middle of an object # and the object is not the last one in the data - if index > 0 and data[index - - 1] == "\n" and not object_offsets: + if ( + index > 0 + and data[index - 1] == "\n" + and not object_offsets + ): object_offsets.append((0, index - 1)) elif char == "}" and stack: stack.pop() # this covers cases where the last object is either followed by a newline or # it is missing - if len(stack) == 0 and (len(data) == index + 1 or - data[index + 1] == "\n" - ) and current_object_start is not None: + if ( + len(stack) == 0 + and (len(data) == index + 1 or data[index + 1] == "\n") + and current_object_start is not None + ): object_offsets.append((current_object_start, index + 1)) current_object_start = None @@ -162,7 +167,7 @@ def convert( yield JsonConverterOutput( current_offset=current_offset + offset_start, current_line=current_line + line, - json_str=raw_data[offset_start:offset_end + 1].strip(), + json_str=raw_data[offset_start : offset_end + 1].strip(), ) @@ -179,8 +184,7 @@ class FileConverterOutput: class FileConverter(Converter[FileConverterOutput]): - """Converts data to a file. - """ + """Converts data to a file.""" def __init__(self, file_path: str) -> None: super().__init__() @@ -224,8 +228,8 @@ def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]: """Retrieves the file.""" def _get_file_content( - self, query: str, variables: dict, - result_field_name: str) -> Tuple[_MetadataFileInfo, str]: + self, query: str, variables: dict, result_field_name: str + ) -> Tuple[_MetadataFileInfo, str]: """Runs the query.""" res = self._ctx.client.execute(query, variables, error_log_key="errors") res = res["task"][result_field_name] @@ -233,14 +237,17 @@ def _get_file_content( if not file_info: raise ValueError( f"Task {self._ctx.task_id} does not have a metadata file for the " - f"{self._ctx.stream_type.value} stream") + f"{self._ctx.stream_type.value} stream" + ) response = requests.get(file_info.file, timeout=30) response.raise_for_status() - assert len( - response.content - ) == file_info.offsets.end - file_info.offsets.start + 1, ( + assert ( + len(response.content) + == file_info.offsets.end - file_info.offsets.start + 1 + ), ( f"expected {file_info.offsets.end - file_info.offsets.start + 1} bytes, " - f"got {len(response.content)} bytes") + f"got {len(response.content)} bytes" + ) return file_info, response.text @@ -260,8 +267,9 @@ def __init__( f"offset is out of range, max offset is {self._ctx.metadata_header.total_size - 1}" ) - def _find_line_at_offset(self, file_content: str, - target_offset: int) -> int: + def _find_line_at_offset( + self, file_content: str, target_offset: int + ) -> int: # TODO: Remove this, incorrect parsing of JSON to find braces stack = [] line_number = 0 @@ -288,22 +296,24 @@ def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]: f"{{task(where: $where)" f"{{{'exportFileFromOffset'}(streamType: $streamType, offset: $offset)" f"{{offsets {{start end}} lines {{start end}} file}}" - f"}}}}") + f"}}}}" + ) variables = { - "where": { - "id": self._ctx.task_id - }, + "where": {"id": self._ctx.task_id}, "streamType": self._ctx.stream_type.value, "offset": str(self._current_offset), } file_info, file_content = self._get_file_content( - query, variables, "exportFileFromOffset") + query, variables, "exportFileFromOffset" + ) if self._current_line is 
None: self._current_line = self._find_line_at_offset( - file_content, self._current_offset - file_info.offsets.start) + file_content, self._current_offset - file_info.offsets.start + ) self._current_line += file_info.lines.start - file_content = file_content[self._current_offset - - file_info.offsets.start:] + file_content = file_content[ + self._current_offset - file_info.offsets.start : + ] file_info.offsets.start = self._current_offset file_info.lines.start = self._current_line self._current_offset = file_info.offsets.end + 1 @@ -357,22 +367,24 @@ def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]: f"{{task(where: $where)" f"{{{'exportFileFromLine'}(streamType: $streamType, line: $line)" f"{{offsets {{start end}} lines {{start end}} file}}" - f"}}}}") + f"}}}}" + ) variables = { - "where": { - "id": self._ctx.task_id - }, + "where": {"id": self._ctx.task_id}, "streamType": self._ctx.stream_type.value, "line": self._current_line, } file_info, file_content = self._get_file_content( - query, variables, "exportFileFromLine") + query, variables, "exportFileFromLine" + ) if self._current_offset is None: self._current_offset = self._find_offset_of_line( - file_content, self._current_line - file_info.lines.start) + file_content, self._current_line - file_info.lines.start + ) self._current_offset += file_info.offsets.start - file_content = file_content[self._current_offset - - file_info.offsets.start:] + file_content = file_content[ + self._current_offset - file_info.offsets.start : + ] file_info.offsets.start = self._current_offset file_info.lines.start = self._current_line self._current_offset = file_info.offsets.end + 1 @@ -394,7 +406,7 @@ def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]: class _MultiGCSFileReader(_Reader): # pylint: disable=too-few-public-methods """Reads data from multiple GCS files in a seamless way. - + Deprecated: This reader is deprecated and will be removed in a future release. """ @@ -437,7 +449,9 @@ def __init__( def __iter__(self): yield from self._fetch() - def _fetch(self,) -> Iterator[OutputT]: + def _fetch( + self, + ) -> Iterator[OutputT]: """Fetches the result data. Returns an iterator that yields the offset and the data. """ @@ -448,25 +462,27 @@ def _fetch(self,) -> Iterator[OutputT]: with self._converter as converter: for file_info, raw_data in stream: for output in converter.convert( - Converter.ConverterInputArgs(self._ctx, file_info, - raw_data)): + Converter.ConverterInputArgs(self._ctx, file_info, raw_data) + ): yield output def with_offset(self, offset: int) -> "Stream[OutputT]": """Sets the offset for the stream.""" self._reader.set_retrieval_strategy( - FileRetrieverByOffset(self._ctx, offset)) + FileRetrieverByOffset(self._ctx, offset) + ) return self def with_line(self, line: int) -> "Stream[OutputT]": """Sets the line number for the stream.""" - self._reader.set_retrieval_strategy(FileRetrieverByLine( - self._ctx, line)) + self._reader.set_retrieval_strategy( + FileRetrieverByLine(self._ctx, line) + ) return self def start( - self, - stream_handler: Optional[Callable[[OutputT], None]] = None) -> None: + self, stream_handler: Optional[Callable[[OutputT], None]] = None + ) -> None: """Starts streaming the result data. Calls the stream_handler for each result. 
""" @@ -501,16 +517,16 @@ def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]: f"{{task(where: $where)" f"{{{'exportFileFromOffset'}(streamType: $streamType, offset: $offset)" f"{{offsets {{start end}} lines {{start end}} file}}" - f"}}}}") + f"}}}}" + ) variables = { - "where": { - "id": self._ctx.task_id - }, + "where": {"id": self._ctx.task_id}, "streamType": self._ctx.stream_type.value, "offset": str(self._current_offset), } file_info, file_content = self._get_file_content( - query, variables, "exportFileFromOffset") + query, variables, "exportFileFromOffset" + ) file_info.offsets.start = self._current_offset file_info.lines.start = self._current_line self._current_offset = file_info.offsets.end + 1 @@ -529,12 +545,15 @@ def __init__( self._reader = _BufferedGCSFileReader() self._converter = _BufferedJsonConverter() self._reader.set_retrieval_strategy( - _BufferedFileRetrieverByOffset(self._ctx, 0)) + _BufferedFileRetrieverByOffset(self._ctx, 0) + ) def __iter__(self): yield from self._fetch() - def _fetch(self,) -> Iterator[OutputT]: + def _fetch( + self, + ) -> Iterator[OutputT]: """Fetches the result data. Returns an iterator that yields the offset and the data. """ @@ -545,13 +564,13 @@ def _fetch(self,) -> Iterator[OutputT]: with self._converter as converter: for file_info, raw_data in stream: for output in converter.convert( - Converter.ConverterInputArgs(self._ctx, file_info, - raw_data)): + Converter.ConverterInputArgs(self._ctx, file_info, raw_data) + ): yield output def start( - self, - stream_handler: Optional[Callable[[OutputT], None]] = None) -> None: + self, stream_handler: Optional[Callable[[OutputT], None]] = None + ) -> None: """Starts streaming the result data. Calls the stream_handler for each result. """ @@ -564,12 +583,12 @@ def start( @dataclass class BufferedJsonConverterOutput: """Output with the JSON object""" + json: Any class _BufferedJsonConverter(Converter[BufferedJsonConverterOutput]): - """Converts JSON data in a buffered manner - """ + """Converts JSON data in a buffered manner""" def convert( self, input_args: Converter.ConverterInputArgs @@ -592,7 +611,7 @@ def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]: if not self._retrieval_strategy: raise ValueError("retrieval strategy not set") # create a buffer - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: result = self._retrieval_strategy.get_next_chunk() while result: _, raw_data = result @@ -604,12 +623,16 @@ def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]: temp_file.write(raw_data) result = self._retrieval_strategy.get_next_chunk() # read buffer - with open(temp_file.name, 'r') as temp_file_reopened: + with open(temp_file.name, "r") as temp_file_reopened: for idx, line in enumerate(temp_file_reopened): - yield _MetadataFileInfo(offsets=Range(start=0, - end=len(line) - 1), - lines=Range(start=idx, end=idx + 1), - file=temp_file.name), line + yield ( + _MetadataFileInfo( + offsets=Range(start=0, end=len(line) - 1), + lines=Range(start=idx, end=idx + 1), + file=temp_file.name, + ), + line, + ) # manually delete buffer os.unlink(temp_file.name) @@ -632,8 +655,11 @@ def __init__(self, task: Task, is_export_v2: bool = False) -> None: self._task = task def __repr__(self): - return f"" if getattr( - self, "uid", None) else "" + return ( + f"" + if getattr(self, "uid", None) + else "" + ) def __str__(self): properties_to_include = [ @@ -702,8 +728,13 @@ def result_url(self): "This 
property is only available for export_v2 tasks due to compatibility reasons, please use streamable errors instead" ) base_url = self._task.client.rest_endpoint - return base_url + '/export-results/' + self._task.uid + '/' + self._task.client.get_organization( - ).uid + return ( + base_url + + "/export-results/" + + self._task.uid + + "/" + + self._task.client.get_organization().uid + ) @property def errors_url(self): @@ -715,8 +746,13 @@ def errors_url(self): if not self.has_errors(): return None base_url = self._task.client.rest_endpoint - return base_url + '/export-errors/' + self._task.uid + '/' + self._task.client.get_organization( - ).uid + return ( + base_url + + "/export-errors/" + + self._task.uid + + "/" + + self._task.client.get_organization().uid + ) @property def errors(self): @@ -736,14 +772,18 @@ def errors(self): data = [] metadata_header = ExportTask._get_metadata_header( - self._task.client, self._task.uid, StreamType.ERRORS) + self._task.client, self._task.uid, StreamType.ERRORS + ) if metadata_header is None: return None BufferedStream( _TaskContext( - self._task.client, self._task.uid, StreamType.ERRORS, - metadata_header),).start( - stream_handler=lambda output: data.append(output.json)) + self._task.client, + self._task.uid, + StreamType.ERRORS, + metadata_header, + ), + ).start(stream_handler=lambda output: data.append(output.json)) return data @property @@ -757,14 +797,18 @@ def result(self): data = [] metadata_header = ExportTask._get_metadata_header( - self._task.client, self._task.uid, StreamType.RESULT) + self._task.client, self._task.uid, StreamType.RESULT + ) if metadata_header is None: return [] BufferedStream( _TaskContext( - self._task.client, self._task.uid, StreamType.RESULT, - metadata_header),).start( - stream_handler=lambda output: data.append(output.json)) + self._task.client, + self._task.uid, + StreamType.RESULT, + metadata_header, + ), + ).start(stream_handler=lambda output: data.append(output.json)) return data return self._task.result_url @@ -798,15 +842,17 @@ def wait_till_done(self, timeout_seconds: int = 7200) -> None: @staticmethod @lru_cache(maxsize=5) def _get_metadata_header( - client, task_id: str, - stream_type: StreamType) -> Union[_MetadataHeader, None]: + client, task_id: str, stream_type: StreamType + ) -> Union[_MetadataHeader, None]: """Returns the total file size for a specific task.""" - query = (f"query GetExportMetadataHeaderPyApi" - f"($where: WhereUniqueIdInput, $streamType: TaskStreamType!)" - f"{{task(where: $where)" - f"{{{'exportMetadataHeader'}(streamType: $streamType)" - f"{{totalSize totalLines}}" - f"}}}}") + query = ( + f"query GetExportMetadataHeaderPyApi" + f"($where: WhereUniqueIdInput, $streamType: TaskStreamType!)" + f"{{task(where: $where)" + f"{{{'exportMetadataHeader'}(streamType: $streamType)" + f"{{totalSize totalLines}}" + f"}}}}" + ) variables = {"where": {"id": task_id}, "streamType": stream_type.value} res = client.execute(query, variables, error_log_key="errors") res = res["task"]["exportMetadataHeader"] @@ -818,8 +864,9 @@ def get_total_file_size(self, stream_type: StreamType) -> Union[int, None]: raise ExportTask.ExportTaskException("Task failed") if self._task.status != "COMPLETE": raise ExportTask.ExportTaskException("Task is not ready yet") - header = ExportTask._get_metadata_header(self._task.client, - self._task.uid, stream_type) + header = ExportTask._get_metadata_header( + self._task.client, self._task.uid, stream_type + ) return header.total_size if header else None def get_total_lines(self, 
stream_type: StreamType) -> Union[int, None]: @@ -828,8 +875,9 @@ def get_total_lines(self, stream_type: StreamType) -> Union[int, None]: raise ExportTask.ExportTaskException("Task failed") if self._task.status != "COMPLETE": raise ExportTask.ExportTaskException("Task is not ready yet") - header = ExportTask._get_metadata_header(self._task.client, - self._task.uid, stream_type) + header = ExportTask._get_metadata_header( + self._task.client, self._task.uid, stream_type + ) return header.total_lines if header else None def has_result(self) -> bool: @@ -864,15 +912,18 @@ def get_buffered_stream( if self._task.status != "COMPLETE": raise ExportTask.ExportTaskException("Task is not ready yet") - metadata_header = self._get_metadata_header(self._task.client, - self._task.uid, stream_type) + metadata_header = self._get_metadata_header( + self._task.client, self._task.uid, stream_type + ) if metadata_header is None: raise ValueError( f"Task {self._task.uid} does not have a {stream_type.value} stream" ) return BufferedStream( - _TaskContext(self._task.client, self._task.uid, stream_type, - metadata_header),) + _TaskContext( + self._task.client, self._task.uid, stream_type, metadata_header + ), + ) @overload def get_stream( @@ -906,15 +957,17 @@ def get_stream( if self._task.status != "COMPLETE": raise ExportTask.ExportTaskException("Task is not ready yet") - metadata_header = self._get_metadata_header(self._task.client, - self._task.uid, stream_type) + metadata_header = self._get_metadata_header( + self._task.client, self._task.uid, stream_type + ) if metadata_header is None: raise ValueError( f"Task {self._task.uid} does not have a {stream_type.value} stream" ) return Stream( - _TaskContext(self._task.client, self._task.uid, stream_type, - metadata_header), + _TaskContext( + self._task.client, self._task.uid, stream_type, metadata_header + ), _MultiGCSFileReader(), converter, ) @@ -923,4 +976,3 @@ def get_stream( def get_task(client, task_id): """Returns the task with the given id.""" return ExportTask(Task.get_task(client, task_id)) - \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/schema/foundry/app.py b/libs/labelbox/src/labelbox/schema/foundry/app.py index f73d5056f..2886dec15 100644 --- a/libs/labelbox/src/labelbox/schema/foundry/app.py +++ b/libs/labelbox/src/labelbox/schema/foundry/app.py @@ -13,7 +13,7 @@ class App(_CamelCaseMixin): class_to_schema_id: Dict[str, str] ontology_id: str created_by: Optional[str] = None - + model_config = ConfigDict(protected_namespaces=()) @classmethod @@ -21,4 +21,4 @@ def type_name(cls): return "App" -APP_FIELD_NAMES = list(App.model_json_schema()['properties'].keys()) +APP_FIELD_NAMES = list(App.model_json_schema()["properties"].keys()) diff --git a/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py b/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py index 27d577bc0..914a363c7 100644 --- a/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py +++ b/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py @@ -6,7 +6,6 @@ class FoundryClient: - def __init__(self, client): self.client = client @@ -35,7 +34,7 @@ def _create_app(self, app: App) -> App: try: response = self.client.execute(query_str, params) except exceptions.LabelboxError as e: - raise exceptions.LabelboxError('Unable to create app', e) + raise exceptions.LabelboxError("Unable to create app", e) return App(**response["createModelFoundryApp"]) def _get_app(self, id: str) -> App: @@ -55,7 +54,7 @@ def _get_app(self, id: str) -> App: except 
exceptions.InvalidQueryError as e: raise exceptions.ResourceNotFoundError(App, params) except Exception as e: - raise exceptions.LabelboxError(f'Unable to get app with id {id}', e) + raise exceptions.LabelboxError(f"Unable to get app with id {id}", e) return App(**response["findModelFoundryApp"]) def _delete_app(self, id: str) -> None: @@ -70,11 +69,16 @@ def _delete_app(self, id: str) -> None: try: self.client.execute(query_str, params) except Exception as e: - raise exceptions.LabelboxError(f'Unable to delete app with id {id}', - e) + raise exceptions.LabelboxError( + f"Unable to delete app with id {id}", e + ) - def run_app(self, model_run_name: str, - data_rows: Union[DataRowIds, GlobalKeys], app_id: str) -> Task: + def run_app( + self, + model_run_name: str, + data_rows: Union[DataRowIds, GlobalKeys], + app_id: str, + ) -> Task: app = self._get_app(app_id) params = { @@ -82,10 +86,14 @@ def run_app(self, model_run_name: str, "name": model_run_name, "classToSchemaId": app.class_to_schema_id, "inferenceParams": app.inference_params, - "ontologyId": app.ontology_id + "ontologyId": app.ontology_id, } - data_rows_key = "dataRowIds" if data_rows.id_type == IdType.DataRowId else "globalKeys" + data_rows_key = ( + "dataRowIds" + if data_rows.id_type == IdType.DataRowId + else "globalKeys" + ) params[data_rows_key] = list(data_rows) query = """ @@ -99,6 +107,6 @@ def run_app(self, model_run_name: str, try: response = self.client.execute(query, {"input": params}) except Exception as e: - raise exceptions.LabelboxError('Unable to run foundry app', e) + raise exceptions.LabelboxError("Unable to run foundry app", e) task_id = response["createModelJobForDataRows"]["taskId"] return Task.get_task(self.client, task_id) diff --git a/libs/labelbox/src/labelbox/schema/foundry/model.py b/libs/labelbox/src/labelbox/schema/foundry/model.py index 87fda22f2..6c2ab6d88 100644 --- a/libs/labelbox/src/labelbox/schema/foundry/model.py +++ b/libs/labelbox/src/labelbox/schema/foundry/model.py @@ -15,4 +15,4 @@ class Model(_CamelCaseMixin, BaseModel): created_at: datetime -MODEL_FIELD_NAMES = list(Model.model_json_schema()['properties'].keys()) +MODEL_FIELD_NAMES = list(Model.model_json_schema()["properties"].keys()) diff --git a/libs/labelbox/src/labelbox/schema/iam_integration.py b/libs/labelbox/src/labelbox/schema/iam_integration.py index 00c4f0ae9..cb5309929 100644 --- a/libs/labelbox/src/labelbox/schema/iam_integration.py +++ b/libs/labelbox/src/labelbox/schema/iam_integration.py @@ -17,7 +17,7 @@ class GcpIamIntegrationSettings: class IAMIntegration(DbObject): - """ Represents an IAM integration for delegated access + """Represents an IAM integration for delegated access Attributes: name (str) @@ -31,9 +31,9 @@ class IAMIntegration(DbObject): """ def __init__(self, client, data): - settings = data.pop('settings', None) + settings = data.pop("settings", None) if settings is not None: - type_name = settings.pop('__typename') + type_name = settings.pop("__typename") settings = {snake_case(k): v for k, v in settings.items()} if type_name == "GcpIamIntegrationSettings": self.settings = GcpIamIntegrationSettings(**settings) diff --git a/libs/labelbox/src/labelbox/schema/id_type.py b/libs/labelbox/src/labelbox/schema/id_type.py index a78dc572c..3ecad4ca1 100644 --- a/libs/labelbox/src/labelbox/schema/id_type.py +++ b/libs/labelbox/src/labelbox/schema/id_type.py @@ -15,10 +15,11 @@ class BaseStrEnum(str, Enum): class IdType(BaseStrEnum): """ The type of id used to identify a data row. 
- + Currently supported types are: - DataRowId: The id assigned to a data row by Labelbox. - GlobalKey: The id assigned to a data row by the user. """ + DataRowId = "ID" GlobalKey = "GKEY" diff --git a/libs/labelbox/src/labelbox/schema/identifiables.py b/libs/labelbox/src/labelbox/schema/identifiables.py index 73a6c4bb3..590ac70c9 100644 --- a/libs/labelbox/src/labelbox/schema/identifiables.py +++ b/libs/labelbox/src/labelbox/schema/identifiables.py @@ -4,7 +4,6 @@ class Identifiables: - def __init__(self, iterable, id_type: str): """ Args: @@ -36,7 +35,10 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: if not isinstance(other, Identifiables): return False - return self._iterable == other._iterable and self._id_type == other._id_type + return ( + self._iterable == other._iterable + and self._id_type == other._id_type + ) class UniqueIds(Identifiables): diff --git a/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py b/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py index 62962d70d..817a02561 100644 --- a/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py +++ b/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py @@ -2,8 +2,14 @@ from typing import List -from labelbox.schema.internal.data_row_upsert_item import DataRowItemBase, DataRowUpsertItem, DataRowCreateItem -from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator +from labelbox.schema.internal.data_row_upsert_item import ( + DataRowItemBase, + DataRowUpsertItem, + DataRowCreateItem, +) +from labelbox.schema.internal.descriptor_file_creator import ( + DescriptorFileCreator, +) from pydantic import BaseModel @@ -16,22 +22,27 @@ class UploadManifest(BaseModel): SOURCE_SDK = "SDK" -def upload_in_chunks(client, specs: List[DataRowItemBase], - file_upload_thread_count: int, - max_chunk_size_bytes: int) -> UploadManifest: +def upload_in_chunks( + client, + specs: List[DataRowItemBase], + file_upload_thread_count: int, + max_chunk_size_bytes: int, +) -> UploadManifest: empty_specs = list(filter(lambda spec: spec.is_empty(), specs)) if empty_specs: ids = list(map(lambda spec: spec.id.get("value"), empty_specs)) ids = list(filter(lambda x: x is not None and len(x) > 0, ids)) if len(ids) > 0: raise ValueError( - f"The following items have an empty payload: {ids}") + f"The following items have an empty payload: {ids}" + ) else: # case of create items raise ValueError("Some items have an empty payload") chunk_uris = DescriptorFileCreator(client).create( - specs, max_chunk_size_bytes=max_chunk_size_bytes) + specs, max_chunk_size_bytes=max_chunk_size_bytes + ) - return UploadManifest(source=SOURCE_SDK, - item_count=len(specs), - chunk_uris=chunk_uris) + return UploadManifest( + source=SOURCE_SDK, item_count=len(specs), chunk_uris=chunk_uris + ) diff --git a/libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py b/libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py index 5759ca818..cc9bbb2c3 100644 --- a/libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py +++ b/libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py @@ -16,30 +16,30 @@ class DataRowItemBase(ABC, BaseModel): payload: dict @abstractmethod - def is_empty(self) -> bool: - ... + def is_empty(self) -> bool: ... 
@classmethod def build( cls, dataset_id: str, items: List[dict], - key_types: Optional[Tuple[type, ...]] = () + key_types: Optional[Tuple[type, ...]] = (), ) -> List["DataRowItemBase"]: upload_items = [] for item in items: # enforce current dataset's id for all specs - item['dataset_id'] = dataset_id - key = item.pop('key', None) + item["dataset_id"] = dataset_id + key = item.pop("key", None) if not key: - key = {'type': 'AUTO', 'value': ''} + key = {"type": "AUTO", "value": ""} elif isinstance(key, key_types): # type: ignore - key = {'type': key.id_type.value, 'value': key.key} + key = {"type": key.id_type.value, "value": key.key} else: if not key_types: raise ValueError( - f"Can not have a key for this item, got: {key}") + f"Can not have a key for this item, got: {key}" + ) raise ValueError( f"Key must be an instance of {', '.join([t.__name__ for t in key_types])}, got: {type(item['key']).__name__}" ) @@ -51,27 +51,28 @@ def build( class DataRowUpsertItem(DataRowItemBase): - def is_empty(self) -> bool: """ The payload is considered empty if it's actually empty or the only key is `dataset_id`. :return: bool """ - return (not self.payload or - len(self.payload.keys()) == 1 and "dataset_id" in self.payload) + return ( + not self.payload + or len(self.payload.keys()) == 1 + and "dataset_id" in self.payload + ) @classmethod def build( cls, dataset_id: str, items: List[dict], - key_types: Optional[Tuple[type, ...]] = (UniqueId, GlobalKey) + key_types: Optional[Tuple[type, ...]] = (UniqueId, GlobalKey), ) -> List["DataRowItemBase"]: return super().build(dataset_id, items, (UniqueId, GlobalKey)) class DataRowCreateItem(DataRowItemBase): - def is_empty(self) -> bool: """ The payload is considered empty if it's actually empty or row_data is empty @@ -79,22 +80,28 @@ def is_empty(self) -> bool: :return: bool """ row_data = self.payload.get("row_data", None) or self.payload.get( - DataRow.row_data, None) + DataRow.row_data, None + ) - return (not self._is_legacy_conversational_data() and - (not self.payload or len(self.payload.keys()) == 1 and - "dataset_id" in self.payload or row_data is None or - len(row_data) == 0)) + return not self._is_legacy_conversational_data() and ( + not self.payload + or len(self.payload.keys()) == 1 + and "dataset_id" in self.payload + or row_data is None + or len(row_data) == 0 + ) def _is_legacy_conversational_data(self) -> bool: - return "conversationalData" in self.payload.keys( - ) or "conversational_data" in self.payload.keys() + return ( + "conversationalData" in self.payload.keys() + or "conversational_data" in self.payload.keys() + ) @classmethod def build( cls, dataset_id: str, items: List[dict], - key_types: Optional[Tuple[type, ...]] = () + key_types: Optional[Tuple[type, ...]] = (), ) -> List["DataRowItemBase"]: return super().build(dataset_id, items, ()) diff --git a/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py b/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py index 07128fdd1..ce3ce4b35 100644 --- a/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py +++ b/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py @@ -12,10 +12,15 @@ from labelbox.orm.model import Field from labelbox.schema.embedding import EmbeddingVector from labelbox.schema.internal.datarow_upload_constants import ( - FILE_UPLOAD_THREAD_COUNT) -from labelbox.schema.internal.data_row_upsert_item import DataRowItemBase, DataRowUpsertItem + FILE_UPLOAD_THREAD_COUNT, +) +from 
labelbox.schema.internal.data_row_upsert_item import ( + DataRowItemBase, + DataRowUpsertItem, +) from typing import TYPE_CHECKING + if TYPE_CHECKING: from labelbox import Client @@ -40,19 +45,25 @@ def create(self, items, max_chunk_size_bytes=None) -> List[str]: json_chunks = self._chunk_down_by_bytes(items, max_chunk_size_bytes) with ThreadPoolExecutor(FILE_UPLOAD_THREAD_COUNT) as executor: futures = [ - executor.submit(self.client.upload_data, chunk, - "application/json", "json_import.json") + executor.submit( + self.client.upload_data, + chunk, + "application/json", + "json_import.json", + ) for chunk in json_chunks ] return [future.result() for future in as_completed(futures)] def create_one(self, items) -> List[str]: - items = self._prepare_items_for_upload(items,) + items = self._prepare_items_for_upload( + items, + ) # Prepare and upload the descriptor file data = json.dumps(items) - return self.client.upload_data(data, - content_type="application/json", - filename="json_import.json") + return self.client.upload_data( + data, content_type="application/json", filename="json_import.json" + ) def _prepare_items_for_upload(self, items, is_upsert=False): """ @@ -99,20 +110,20 @@ def _prepare_items_for_upload(self, items, is_upsert=False): AssetAttachment = Entity.AssetAttachment def upload_if_necessary(item): - if is_upsert and 'row_data' not in item: + if is_upsert and "row_data" not in item: # When upserting, row_data is not required return item - row_data = item['row_data'] + row_data = item["row_data"] if isinstance(row_data, str) and os.path.exists(row_data): item_url = self.client.upload_file(row_data) - item['row_data'] = item_url - if 'external_id' not in item: + item["row_data"] = item_url + if "external_id" not in item: # Default `external_id` to local file name - item['external_id'] = row_data + item["external_id"] = row_data return item def validate_attachments(item): - attachments = item.get('attachments') + attachments = item.get("attachments") if attachments: if isinstance(attachments, list): for attachment in attachments: @@ -139,18 +150,25 @@ def validate_conversational_data(conversational_data: list) -> None: """ def check_message_keys(message): - accepted_message_keys = set([ - "messageId", "timestampUsec", "content", "user", "align", - "canLabel" - ]) + accepted_message_keys = set( + [ + "messageId", + "timestampUsec", + "content", + "user", + "align", + "canLabel", + ] + ) for key in message.keys(): if not key in accepted_message_keys: raise KeyError( f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}" ) - if conversational_data and not isinstance(conversational_data, - list): + if conversational_data and not isinstance( + conversational_data, list + ): raise ValueError( f"conversationalData must be a list. 
Found {type(conversational_data)}" ) @@ -158,11 +176,12 @@ def check_message_keys(message): [check_message_keys(message) for message in conversational_data] def parse_metadata_fields(item): - metadata_fields = item.get('metadata_fields') + metadata_fields = item.get("metadata_fields") if metadata_fields: mdo = self.client.get_data_row_metadata_ontology() - item['metadata_fields'] = mdo.parse_upsert_metadata( - metadata_fields) + item["metadata_fields"] = mdo.parse_upsert_metadata( + metadata_fields + ) def format_row(item): # Formats user input into a consistent dict structure @@ -182,19 +201,28 @@ def format_row(item): return item def validate_keys(item): - if not is_upsert and 'row_data' not in item: + if not is_upsert and "row_data" not in item: raise InvalidQueryError( - "`row_data` missing when creating DataRow.") + "`row_data` missing when creating DataRow." + ) - if isinstance(item.get('row_data'), - str) and item.get('row_data').startswith("s3:/"): + if isinstance(item.get("row_data"), str) and item.get( + "row_data" + ).startswith("s3:/"): raise InvalidQueryError( - "row_data: s3 assets must start with 'https'.") + "row_data: s3 assets must start with 'https'." + ) allowed_extra_fields = { - 'attachments', 'media_type', 'dataset_id', 'embeddings' + "attachments", + "media_type", + "dataset_id", + "embeddings", } - invalid_keys = set(item) - {f.name for f in DataRow.fields() - } - allowed_extra_fields + invalid_keys = ( + set(item) + - {f.name for f in DataRow.fields()} + - allowed_extra_fields + ) if invalid_keys: raise InvalidAttributeError(DataRow, invalid_keys) return item @@ -210,12 +238,11 @@ def format_legacy_conversational_data(item): global_key = item.pop("globalKey") item["globalKey"] = global_key validate_conversational_data(messages) - one_conversation = \ - { - "type": type, - "version": version, - "messages": messages - } + one_conversation = { + "type": type, + "version": version, + "messages": messages, + } item["row_data"] = one_conversation return item @@ -246,7 +273,7 @@ def convert_item(data_row_item): item = upload_if_necessary(item) if isinstance(data_row_item, DataRowItemBase): - return {'id': data_row_item.id, 'payload': item} + return {"id": data_row_item.id, "payload": item} else: return item @@ -261,8 +288,9 @@ def convert_item(data_row_item): return items - def _chunk_down_by_bytes(self, items: List[dict], - max_chunk_size: int) -> Generator[str, None, None]: + def _chunk_down_by_bytes( + self, items: List[dict], max_chunk_size: int + ) -> Generator[str, None, None]: """ Recursively chunks down a list of items into smaller lists until each list is less than or equal to max_chunk_size bytes NOTE: if one data row is larger than max_chunk_size, it will be returned as one chunk diff --git a/libs/labelbox/src/labelbox/schema/invite.py b/libs/labelbox/src/labelbox/schema/invite.py index 266e14c7f..c89a8b08c 100644 --- a/libs/labelbox/src/labelbox/schema/invite.py +++ b/libs/labelbox/src/labelbox/schema/invite.py @@ -22,6 +22,7 @@ class Invite(DbObject): """ An object representing a user invite """ + created_at = Field.DateTime("created_at") organization_role_name = Field.String("organization_role_name") email = Field.String("email", "inviteeEmail") @@ -31,7 +32,9 @@ def __init__(self, client, invite_response): super().__init__(client, invite_response) self.project_roles = [ - ProjectRole(project=client.get_project(r['projectId']), - role=client.get_roles()[format_role( - r['projectRoleName'])]) for r in project_roles + ProjectRole( + 
project=client.get_project(r["projectId"]), + role=client.get_roles()[format_role(r["projectRoleName"])], + ) + for r in project_roles ] diff --git a/libs/labelbox/src/labelbox/schema/label.py b/libs/labelbox/src/labelbox/schema/label.py index 7a7d2dc51..371193a13 100644 --- a/libs/labelbox/src/labelbox/schema/label.py +++ b/libs/labelbox/src/labelbox/schema/label.py @@ -10,7 +10,7 @@ class Label(DbObject, Updateable, BulkDeletable): - """ Label represents an assessment on a DataRow. For example one label could + """Label represents an assessment on a DataRow. For example one label could contain 100 bounding boxes (annotations). Attributes: @@ -46,7 +46,7 @@ def __init__(self, *args, **kwargs): @staticmethod def bulk_delete(labels) -> None: - """ Deletes all the given Labels. + """Deletes all the given Labels. Args: labels (list of Label): The Labels to delete. @@ -54,7 +54,7 @@ def bulk_delete(labels) -> None: BulkDeletable._bulk_delete(labels, False) def create_review(self, **kwargs) -> "Review": - """ Creates a Review for this label. + """Creates a Review for this label. Args: **kwargs: Review attributes. At a minimum, a `Review.score` field value must be provided. @@ -64,7 +64,7 @@ def create_review(self, **kwargs) -> "Review": return self.client._create(Entity.Review, kwargs) def create_benchmark(self) -> "Benchmark": - """ Creates a Benchmark for this Label. + """Creates a Benchmark for this Label. Returns: The newly created Benchmark. @@ -72,7 +72,9 @@ def create_benchmark(self) -> "Benchmark": label_id_param = "labelId" query_str = """mutation CreateBenchmarkPyApi($%s: ID!) { createBenchmark(data: {labelId: $%s}) {%s}} """ % ( - label_id_param, label_id_param, - query.results_query_part(Entity.Benchmark)) + label_id_param, + label_id_param, + query.results_query_part(Entity.Benchmark), + ) res = self.client.execute(query_str, {label_id_param: self.uid}) return Entity.Benchmark(self.client, res["createBenchmark"]) diff --git a/libs/labelbox/src/labelbox/schema/labeling_frontend.py b/libs/labelbox/src/labelbox/schema/labeling_frontend.py index 147148ece..49bc8825f 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_frontend.py +++ b/libs/labelbox/src/labelbox/schema/labeling_frontend.py @@ -3,7 +3,7 @@ class LabelingFrontend(DbObject): - """ Label editor. + """Label editor. Represents an HTML / JavaScript UI that is used to generate labels. “Editor” is the default Labeling Frontend that comes in every @@ -16,13 +16,14 @@ class LabelingFrontend(DbObject): projects (Relationship): `ToMany` relationship to Project """ + name = Field.String("name") description = Field.String("description") iframe_url_path = Field.String("iframe_url_path") class LabelingFrontendOptions(DbObject): - """ Label interface options. + """Label interface options. Attributes: customization_options (str) @@ -31,6 +32,7 @@ class LabelingFrontendOptions(DbObject): labeling_frontend (Relationship): `ToOne` relationship to LabelingFrontend organization (Relationship): `ToOne` relationship to Organization """ + customization_options = Field.String("customization_options") project = Relationship.ToOne("Project") diff --git a/libs/labelbox/src/labelbox/schema/labeling_service.py b/libs/labelbox/src/labelbox/schema/labeling_service.py index 70376f2e8..a7a1845be 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service.py @@ -16,6 +16,7 @@ class LabelingService(_CamelCaseMixin): """ Labeling service for a project. 
This is a service that can be requested to label data for a project. """ + id: Cuid project_id: Cuid created_at: datetime @@ -28,10 +29,11 @@ def __init__(self, **kwargs): super().__init__(**kwargs) if not self.client.enable_experimental: raise RuntimeError( - "Please enable experimental in client to use LabelingService") + "Please enable experimental in client to use LabelingService" + ) @classmethod - def start(cls, client, project_id: Cuid) -> 'LabelingService': + def start(cls, client, project_id: Cuid) -> "LabelingService": """ Starts the labeling service for the project. This is equivalent to a UI action to Request Specialized Labelers @@ -52,7 +54,7 @@ def start(cls, client, project_id: Cuid) -> 'LabelingService': return cls.get(client, project_id) @classmethod - def get(cls, client, project_id: Cuid) -> 'LabelingService': + def get(cls, client, project_id: Cuid) -> "LabelingService": """ Returns the labeling service associated with the project. @@ -74,14 +76,15 @@ def get(cls, client, project_id: Cuid) -> 'LabelingService': result = client.execute(query, {"projectId": project_id}) if result["projectBoostWorkforce"] is None: raise ResourceNotFoundError( - message="The project does not have a labeling service.") + message="The project does not have a labeling service." + ) data = result["projectBoostWorkforce"] data["client"] = client return LabelingService(**data) - def request(self) -> 'LabelingService': + def request(self) -> "LabelingService": """ - Creates a request to labeling service to start labeling for the project. + Creates a request to labeling service to start labeling for the project. Our back end will validate that the project is ready for labeling and then request the labeling service. Returns: @@ -100,15 +103,18 @@ def request(self) -> 'LabelingService': } } """ - result = self.client.execute(query_str, {"projectId": self.project_id}, - raise_return_resource_not_found=True) + result = self.client.execute( + query_str, + {"projectId": self.project_id}, + raise_return_resource_not_found=True, + ) success = result["validateAndRequestProjectBoostWorkforce"]["success"] if not success: raise Exception("Failed to start labeling service") return LabelingService.get(self.client, self.project_id) @classmethod - def getOrCreate(cls, client, project_id: Cuid) -> 'LabelingService': + def getOrCreate(cls, client, project_id: Cuid) -> "LabelingService": """ Returns the labeling service associated with the project. If the project does not have a labeling service, it will create one. @@ -127,4 +133,4 @@ def dashboard(self) -> LabelingServiceDashboard: Raises: ResourceNotFoundError: If the project does not have a labeling service. 
""" - return LabelingServiceDashboard.get(self.client, self.project_id) \ No newline at end of file + return LabelingServiceDashboard.get(self.client, self.project_id) diff --git a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py index 41ce1f4d1..10a956a66 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py @@ -64,6 +64,7 @@ class LabelingServiceDashboard(_CamelCaseMixin): editor_task_type (EditorTaskType): editor task type of the project client (Any): labelbox client """ + id: str = Field(frozen=True) name: str = Field(frozen=True) created_at: Optional[datetime] = Field(frozen=True, default=None) @@ -83,7 +84,8 @@ def __init__(self, **kwargs): super().__init__(**kwargs) if not self.client.enable_experimental: raise RuntimeError( - "Please enable experimental in client to use LabelingService") + "Please enable experimental in client to use LabelingService" + ) @property def service_type(self): @@ -96,22 +98,34 @@ def service_type(self): if self.editor_task_type is None: return sentence_case(self.media_type.value) - if self.editor_task_type == EditorTaskType.OfflineModelChatEvaluation and self.media_type == MediaType.Conversational: + if ( + self.editor_task_type == EditorTaskType.OfflineModelChatEvaluation + and self.media_type == MediaType.Conversational + ): return "Offline chat evaluation" - if self.editor_task_type == EditorTaskType.ModelChatEvaluation and self.media_type == MediaType.Conversational: + if ( + self.editor_task_type == EditorTaskType.ModelChatEvaluation + and self.media_type == MediaType.Conversational + ): return "Live chat evaluation" - if self.editor_task_type == EditorTaskType.ResponseCreation and self.media_type == MediaType.Text: + if ( + self.editor_task_type == EditorTaskType.ResponseCreation + and self.media_type == MediaType.Text + ): return "Response creation" - if self.media_type == MediaType.LLMPromptCreation or self.media_type == MediaType.LLMPromptResponseCreation: + if ( + self.media_type == MediaType.LLMPromptCreation + or self.media_type == MediaType.LLMPromptResponseCreation + ): return "Prompt response creation" return sentence_case(self.media_type.value) @classmethod - def get(cls, client, project_id: str) -> 'LabelingServiceDashboard': + def get(cls, client, project_id: str) -> "LabelingServiceDashboard": """ Returns the labeling service associated with the project. 
@@ -140,7 +154,6 @@ def get_all( client, search_query: Optional[List[SearchFilter]] = None, ) -> PaginatedCollection: - if search_query is not None: template = Template( """query SearchProjectsPyApi($$first: Int, $$from: String) { @@ -150,7 +163,8 @@ def get_all( pageInfo { endCursor } } } - """) + """ + ) else: template = Template( """query SearchProjectsPyApi($$first: Int, $$from: String) { @@ -160,46 +174,48 @@ def get_all( pageInfo { endCursor } } } - """) + """ + ) query_str = template.substitute( labeling_dashboard_selections=GRAPHQL_QUERY_SELECTIONS, search_query=build_search_filter(search_query) - if search_query else None, + if search_query + else None, ) params: Dict[str, Union[str, int]] = {} def convert_to_labeling_service_dashboard(client, data): - data['client'] = client + data["client"] = client return LabelingServiceDashboard(**data) return PaginatedCollection( client=client, query=query_str, params=params, - dereferencing=['searchProjects', 'nodes'], + dereferencing=["searchProjects", "nodes"], obj_class=convert_to_labeling_service_dashboard, - cursor_path=['searchProjects', 'pageInfo', 'endCursor'], + cursor_path=["searchProjects", "pageInfo", "endCursor"], experimental=True, ) @root_validator(pre=True) def convert_boost_data(cls, data): - if 'boostStatus' in data: - data['status'] = LabelingServiceStatus(data.pop('boostStatus')) + if "boostStatus" in data: + data["status"] = LabelingServiceStatus(data.pop("boostStatus")) - if 'boostRequestedAt' in data: - data['created_at'] = data.pop('boostRequestedAt') + if "boostRequestedAt" in data: + data["created_at"] = data.pop("boostRequestedAt") - if 'boostUpdatedAt' in data: - data['updated_at'] = data.pop('boostUpdatedAt') + if "boostUpdatedAt" in data: + data["updated_at"] = data.pop("boostUpdatedAt") - if 'boostRequestedBy' in data: - data['created_by_id'] = data.pop('boostRequestedBy') + if "boostRequestedBy" in data: + data["created_by_id"] = data.pop("boostRequestedBy") return data def dict(self, *args, **kwargs): row = super().dict(*args, **kwargs) - row.pop('client') - row['service_type'] = self.service_type + row.pop("client") + row["service_type"] = self.service_type return row diff --git a/libs/labelbox/src/labelbox/schema/labeling_service_status.py b/libs/labelbox/src/labelbox/schema/labeling_service_status.py index 62cfd938e..c15cf73b9 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service_status.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service_status.py @@ -2,12 +2,12 @@ class LabelingServiceStatus(Enum): - Accepted = 'ACCEPTED' - Calibration = 'CALIBRATION' - Complete = 'COMPLETE' - Production = 'PRODUCTION' - Requested = 'REQUESTED' - SetUp = 'SET_UP' + Accepted = "ACCEPTED" + Calibration = "CALIBRATION" + Complete = "COMPLETE" + Production = "PRODUCTION" + Requested = "REQUESTED" + SetUp = "SET_UP" Missing = None @classmethod @@ -15,10 +15,10 @@ def is_supported(cls, value): return isinstance(value, cls) @classmethod - def _missing_(cls, value) -> 'LabelingServiceStatus': + def _missing_(cls, value) -> "LabelingServiceStatus": """Handle missing null new task types - Handle upper case names for compatibility with - the GraphQL""" + Handle upper case names for compatibility with + the GraphQL""" if value is None: return cls.Missing diff --git a/libs/labelbox/src/labelbox/schema/media_type.py b/libs/labelbox/src/labelbox/schema/media_type.py index 99807522b..ae0bbbb3f 100644 --- a/libs/labelbox/src/labelbox/schema/media_type.py +++ b/libs/labelbox/src/labelbox/schema/media_type.py @@ -27,9 
+27,9 @@ class MediaType(Enum): @classmethod def _missing_(cls, value): """Handle missing null data types for projects - created without setting allowedMediaType - Handle upper case names for compatibility with - the GraphQL""" + created without setting allowedMediaType + Handle upper case names for compatibility with + the GraphQL""" if value is None: return cls.Unknown @@ -46,9 +46,11 @@ def matches(value, name): value_underscore = value.replace("-", "_") camel_case_value = camel_case(value_underscore) - return (value_upper == name_upper or - value_underscore.upper() == name_upper or - camel_case_value.upper() == name_upper) + return ( + value_upper == name_upper + or value_underscore.upper() == name_upper + or camel_case_value.upper() == name_upper + ) for name, member in cls.__members__.items(): if matches(value, name): @@ -58,18 +60,23 @@ def matches(value, name): @classmethod def is_supported(cls, value): - return isinstance(value, - cls) and value not in [cls.Unknown, cls.Unsupported] + return isinstance(value, cls) and value not in [ + cls.Unknown, + cls.Unsupported, + ] @classmethod def get_supported_members(cls): return [ - item for item in cls.__members__ + item + for item in cls.__members__ if item not in ["Unknown", "Unsupported"] ] def get_media_type_validation_error(media_type): - return TypeError(f"{media_type} is not a valid media type. Use" - f" any of {MediaType.get_supported_members()}" - " from MediaType. Example: MediaType.Image.") + return TypeError( + f"{media_type} is not a valid media type. Use" + f" any of {MediaType.get_supported_members()}" + " from MediaType. Example: MediaType.Image." + ) diff --git a/libs/labelbox/src/labelbox/schema/model.py b/libs/labelbox/src/labelbox/schema/model.py index 692f43fad..e78620002 100644 --- a/libs/labelbox/src/labelbox/schema/model.py +++ b/libs/labelbox/src/labelbox/schema/model.py @@ -9,18 +9,18 @@ class Model(DbObject): """A model represents a program that has been trained and - can make predictions on new data. - Attributes: - name (str) - model_runs (Relationship): `ToMany` relationship to ModelRun - """ + can make predictions on new data. + Attributes: + name (str) + model_runs (Relationship): `ToMany` relationship to ModelRun + """ name = Field.String("name") ontology_id = Field.String("ontology_id") model_runs = Relationship.ToMany("ModelRun", False) def create_model_run(self, name, config=None) -> "ModelRun": - """ Creates a model run belonging to this model. + """Creates a model run belonging to this model. Args: name (string): The name for the model run. @@ -34,17 +34,22 @@ def create_model_run(self, name, config=None) -> "ModelRun": ModelRun = Entity.ModelRun query_str = """mutation CreateModelRunPyApi($%s: String!, $%s: Json, $%s: ID!) { createModelRun(data: {name: $%s, trainingMetadata: $%s, modelId: $%s}) {%s}}""" % ( - name_param, config_param, model_id_param, name_param, config_param, - model_id_param, query.results_query_part(ModelRun)) - res = self.client.execute(query_str, { - name_param: name, - config_param: config, - model_id_param: self.uid - }) + name_param, + config_param, + model_id_param, + name_param, + config_param, + model_id_param, + query.results_query_part(ModelRun), + ) + res = self.client.execute( + query_str, + {name_param: name, config_param: config, model_id_param: self.uid}, + ) return ModelRun(self.client, res["createModelRun"]) def delete(self) -> None: - """ Deletes specified model. + """Deletes specified model. Returns: Query execution success. 
diff --git a/libs/labelbox/src/labelbox/schema/model_config.py b/libs/labelbox/src/labelbox/schema/model_config.py index 46c0deca9..369315cd0 100644 --- a/libs/labelbox/src/labelbox/schema/model_config.py +++ b/libs/labelbox/src/labelbox/schema/model_config.py @@ -3,9 +3,9 @@ class ModelConfig(DbObject): - """ A ModelConfig represents a set of inference params configured for a model + """A ModelConfig represents a set of inference params configured for a model - Attributes: + Attributes: inference_params (JSON): Dict of inference params model_id (str): ID of the model to configure name (str): Name of config diff --git a/libs/labelbox/src/labelbox/schema/model_run.py b/libs/labelbox/src/labelbox/schema/model_run.py index 7f8714008..73c013b57 100644 --- a/libs/labelbox/src/labelbox/schema/model_run.py +++ b/libs/labelbox/src/labelbox/schema/model_run.py @@ -5,7 +5,16 @@ import warnings from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Dict, Iterable, Union, Tuple, List, Optional, Any +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + Union, + Tuple, + List, + Optional, + Any, +) import requests @@ -14,12 +23,17 @@ from labelbox.orm.model import Field, Relationship, Entity from labelbox.orm.query import results_query_part from labelbox.pagination import PaginatedCollection -from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy +from labelbox.schema.conflict_resolution_strategy import ( + ConflictResolutionStrategy, +) from labelbox.schema.export_params import ModelRunExportParams from labelbox.schema.export_task import ExportTask from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds -from labelbox.schema.send_to_annotate_params import SendToAnnotateFromModelParams, build_destination_task_queue_input, \ - build_predictions_input +from labelbox.schema.send_to_annotate_params import ( + SendToAnnotateFromModelParams, + build_destination_task_queue_input, + build_predictions_input, +) from labelbox.schema.task import Task if TYPE_CHECKING: @@ -53,10 +67,12 @@ class Status(Enum): COMPLETE = "COMPLETE" FAILED = "FAILED" - def upsert_labels(self, - label_ids: Optional[List[str]] = None, - project_id: Optional[str] = None, - timeout_seconds=3600): + def upsert_labels( + self, + label_ids: Optional[List[str]] = None, + project_id: Optional[str] = None, + timeout_seconds=3600, + ): """ Adds data rows and labels to a Model Run @@ -75,7 +91,8 @@ def upsert_labels(self, if not use_label_ids and not use_project_id: raise ValueError( - "Must provide at least one label id or a project id") + "Must provide at least one label id or a project id" + ) if use_label_ids and use_project_id: raise ValueError("Must only one of label ids, project id") @@ -83,60 +100,64 @@ def upsert_labels(self, if use_label_ids: return self._upsert_labels_by_label_ids(label_ids, timeout_seconds) else: # use_project_id - return self._upsert_labels_by_project_id(project_id, - timeout_seconds) + return self._upsert_labels_by_project_id( + project_id, timeout_seconds + ) - def _upsert_labels_by_label_ids(self, label_ids: List[str], - timeout_seconds: int): - mutation_name = 'createMEAModelRunLabelRegistrationTask' + def _upsert_labels_by_label_ids( + self, label_ids: List[str], timeout_seconds: int + ): + mutation_name = "createMEAModelRunLabelRegistrationTask" create_task_query_str = """mutation createMEAModelRunLabelRegistrationTaskPyApi($modelRunId: ID!, $labelIds : [ID!]!) 
{ %s(where : { id : $modelRunId}, data : {labelIds: $labelIds})} """ % (mutation_name) - res = self.client.execute(create_task_query_str, { - 'modelRunId': self.uid, - 'labelIds': label_ids - }) + res = self.client.execute( + create_task_query_str, + {"modelRunId": self.uid, "labelIds": label_ids}, + ) task_id = res[mutation_name] status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){ MEALabelRegistrationTaskStatus(where: $where) {status errorMessage} } """ - return self._wait_until_done(lambda: self.client.execute( - status_query_str, {'where': { - 'id': task_id - }})['MEALabelRegistrationTaskStatus'], - timeout_seconds=timeout_seconds) - - def _upsert_labels_by_project_id(self, project_id: str, - timeout_seconds: int): - mutation_name = 'createMEAModelRunProjectLabelRegistrationTask' + return self._wait_until_done( + lambda: self.client.execute( + status_query_str, {"where": {"id": task_id}} + )["MEALabelRegistrationTaskStatus"], + timeout_seconds=timeout_seconds, + ) + + def _upsert_labels_by_project_id( + self, project_id: str, timeout_seconds: int + ): + mutation_name = "createMEAModelRunProjectLabelRegistrationTask" create_task_query_str = """mutation createMEAModelRunProjectLabelRegistrationTaskPyApi($modelRunId: ID!, $projectId : ID!) { %s(where : { modelRunId : $modelRunId, projectId: $projectId})} """ % (mutation_name) - res = self.client.execute(create_task_query_str, { - 'modelRunId': self.uid, - 'projectId': project_id - }) + res = self.client.execute( + create_task_query_str, + {"modelRunId": self.uid, "projectId": project_id}, + ) task_id = res[mutation_name] status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){ MEALabelRegistrationTaskStatus(where: $where) {status errorMessage} } """ - return self._wait_until_done(lambda: self.client.execute( - status_query_str, {'where': { - 'id': task_id - }})['MEALabelRegistrationTaskStatus'], - timeout_seconds=timeout_seconds) - - def upsert_data_rows(self, - data_row_ids=None, - global_keys=None, - timeout_seconds=3600): - """ Adds data rows to a Model Run without any associated labels + return self._wait_until_done( + lambda: self.client.execute( + status_query_str, {"where": {"id": task_id}} + )["MEALabelRegistrationTaskStatus"], + timeout_seconds=timeout_seconds, + ) + + def upsert_data_rows( + self, data_row_ids=None, global_keys=None, timeout_seconds=3600 + ): + """Adds data rows to a Model Run without any associated labels Args: data_row_ids (list): data row ids to add to model run global_keys (list): global keys for data rows to add to model run @@ -145,37 +166,40 @@ def upsert_data_rows(self, ID of newly generated async task """ - mutation_name = 'createMEAModelRunDataRowRegistrationTask' + mutation_name = "createMEAModelRunDataRowRegistrationTask" create_task_query_str = """mutation createMEAModelRunDataRowRegistrationTaskPyApi($modelRunId: ID!, $dataRowIds: [ID!], $globalKeys: [ID!]) { %s(where : { id : $modelRunId}, data : {dataRowIds: $dataRowIds, globalKeys: $globalKeys})} """ % (mutation_name) res = self.client.execute( - create_task_query_str, { - 'modelRunId': self.uid, - 'dataRowIds': data_row_ids, - 'globalKeys': global_keys - }) + create_task_query_str, + { + "modelRunId": self.uid, + "dataRowIds": data_row_ids, + "globalKeys": global_keys, + }, + ) task_id = res[mutation_name] status_query_str = """query MEADataRowRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){ MEADataRowRegistrationTaskStatus(where: $where) {status 
errorMessage} } """ - return self._wait_until_done(lambda: self.client.execute( - status_query_str, {'where': { - 'id': task_id - }})['MEADataRowRegistrationTaskStatus'], - timeout_seconds=timeout_seconds) + return self._wait_until_done( + lambda: self.client.execute( + status_query_str, {"where": {"id": task_id}} + )["MEADataRowRegistrationTaskStatus"], + timeout_seconds=timeout_seconds, + ) def _wait_until_done(self, status_fn, timeout_seconds=120, sleep_time=5): # Do not use this function outside of the scope of upsert_data_rows or upsert_labels. It could change. original_timeout = timeout_seconds while True: res = status_fn() - if res['status'] == 'COMPLETE': + if res["status"] == "COMPLETE": return True - elif res['status'] == 'FAILED': + elif res["status"] == "FAILED": raise Exception(f"Job failed.") timeout_seconds -= sleep_time if timeout_seconds <= 0: @@ -190,7 +214,7 @@ def upsert_predictions_and_send_to_project( predictions: Union[str, Path, Iterable[Dict]], project_id: str, priority: Optional[int] = 5, - ) -> 'MEAPredictionImport': # type: ignore + ) -> "MEAPredictionImport": # type: ignore """ Provides a convenient way to execute the following steps in a single function call: 1. Upload predictions to a Model @@ -230,11 +254,14 @@ def upsert_predictions_and_send_to_project( import_job = self.add_predictions(name, predictions) prediction_statuses = import_job.statuses mea_to_mal_data_rows = list( - set([ - row['dataRow']['id'] - for row in prediction_statuses - if row['status'] == 'SUCCESS' - ])) + set( + [ + row["dataRow"]["id"] + for row in prediction_statuses + if row["status"] == "SUCCESS" + ] + ) + ) if not mea_to_mal_data_rows: # 0 successful model predictions imported @@ -254,10 +281,13 @@ def upsert_predictions_and_send_to_project( return import_job, None, None try: - mal_prediction_import = Entity.MEAToMALPredictionImport.create_for_model_run_data_rows( - data_row_ids=mea_to_mal_data_rows, - project_id=project_id, - **kwargs) + mal_prediction_import = ( + Entity.MEAToMALPredictionImport.create_for_model_run_data_rows( + data_row_ids=mea_to_mal_data_rows, + project_id=project_id, + **kwargs, + ) + ) mal_prediction_import.wait_until_done() except Exception as e: logger.warning( @@ -272,7 +302,7 @@ def add_predictions( self, name: str, predictions: Union[str, Path, Iterable[Dict], Iterable["Label"]], - ) -> 'MEAPredictionImport': # type: ignore + ) -> "MEAPredictionImport": # type: ignore """ Uploads predictions to a new Editor project. 
@@ -289,17 +319,21 @@ def add_predictions( kwargs = dict(client=self.client, id=self.uid, name=name) if isinstance(predictions, str) or isinstance(predictions, Path): if os.path.exists(predictions): - return Entity.MEAPredictionImport.create(path=str(predictions), - **kwargs) + return Entity.MEAPredictionImport.create( + path=str(predictions), **kwargs + ) else: - return Entity.MEAPredictionImport.create(url=str(predictions), - **kwargs) + return Entity.MEAPredictionImport.create( + url=str(predictions), **kwargs + ) elif isinstance(predictions, Iterable): - return Entity.MEAPredictionImport.create(labels=predictions, - **kwargs) + return Entity.MEAPredictionImport.create( + labels=predictions, **kwargs + ) else: raise ValueError( - f'Invalid predictions given of type: {type(predictions)}') + f"Invalid predictions given of type: {type(predictions)}" + ) def model_run_data_rows(self): query_str = """query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){ @@ -308,13 +342,16 @@ def model_run_data_rows(self): } """ % (results_query_part(ModelRunDataRow)) return PaginatedCollection( - self.client, query_str, {'modelRunId': self.uid}, - ['annotationGroups', 'nodes'], + self.client, + query_str, + {"modelRunId": self.uid}, + ["annotationGroups", "nodes"], lambda client, res: ModelRunDataRow(client, self.model_id, res), - ['annotationGroups', 'pageInfo', 'endCursor']) + ["annotationGroups", "pageInfo", "endCursor"], + ) def delete(self): - """ Deletes specified Model Run. + """Deletes specified Model Run. Returns: Query execution success. @@ -325,7 +362,7 @@ def delete(self): self.client.execute(query_str, {ids_param: str(self.uid)}) def delete_model_run_data_rows(self, data_row_ids: List[str]): - """ Deletes data rows from Model Runs. + """Deletes data rows from Model Runs. Args: data_row_ids (list): List of data row ids to delete from the Model Run. @@ -336,136 +373,150 @@ def delete_model_run_data_rows(self, data_row_ids: List[str]): data_row_ids_param = "dataRowIds" query_str = """mutation DeleteModelRunDataRowsPyApi($%s: ID!, $%s: [ID!]!) { deleteModelRunDataRows(where: {modelRunId: $%s, dataRowIds: $%s})}""" % ( - model_run_id_param, data_row_ids_param, model_run_id_param, - data_row_ids_param) - self.client.execute(query_str, { - model_run_id_param: self.uid, - data_row_ids_param: data_row_ids - }) + model_run_id_param, + data_row_ids_param, + model_run_id_param, + data_row_ids_param, + ) + self.client.execute( + query_str, + {model_run_id_param: self.uid, data_row_ids_param: data_row_ids}, + ) @experimental - def assign_data_rows_to_split(self, - data_row_ids: List[str] = None, - split: Union[DataSplit, str] = None, - global_keys: List[str] = None, - timeout_seconds=120): - + def assign_data_rows_to_split( + self, + data_row_ids: List[str] = None, + split: Union[DataSplit, str] = None, + global_keys: List[str] = None, + timeout_seconds=120, + ): split_value = split.value if isinstance(split, DataSplit) else split valid_splits = DataSplit._member_names_ if split_value is None or split_value not in valid_splits: raise ValueError( - f"`split` must be one of : `{valid_splits}`. Found : `{split}`") + f"`split` must be one of : `{valid_splits}`. 
Found : `{split}`" + ) task_id = self.client.execute( """mutation assignDataSplitPyApi($modelRunId: ID!, $data: CreateAssignDataRowsToDataSplitTaskInput!){ createAssignDataRowsToDataSplitTask(modelRun : {id: $modelRunId}, data: $data)} - """, { - 'modelRunId': self.uid, - 'data': { - 'assignments': [{ - 'split': split_value, - 'dataRowIds': data_row_ids, - 'globalKeys': global_keys, - }] - } + """, + { + "modelRunId": self.uid, + "data": { + "assignments": [ + { + "split": split_value, + "dataRowIds": data_row_ids, + "globalKeys": global_keys, + } + ] + }, }, - experimental=True)['createAssignDataRowsToDataSplitTask'] + experimental=True, + )["createAssignDataRowsToDataSplitTask"] status_query_str = """query assignDataRowsToDataSplitTaskStatusPyApi($id: ID!){ assignDataRowsToDataSplitTaskStatus(where: {id : $id}){status errorMessage}} """ - return self._wait_until_done(lambda: self.client.execute( - status_query_str, {'id': task_id}, experimental=True)[ - 'assignDataRowsToDataSplitTaskStatus'], - timeout_seconds=timeout_seconds) + return self._wait_until_done( + lambda: self.client.execute( + status_query_str, {"id": task_id}, experimental=True + )["assignDataRowsToDataSplitTaskStatus"], + timeout_seconds=timeout_seconds, + ) @experimental - def update_status(self, - status: Union[str, "ModelRun.Status"], - metadata: Optional[Dict[str, str]] = None, - error_message: Optional[str] = None): - - status_value = status.value if isinstance(status, - ModelRun.Status) else status + def update_status( + self, + status: Union[str, "ModelRun.Status"], + metadata: Optional[Dict[str, str]] = None, + error_message: Optional[str] = None, + ): + status_value = ( + status.value if isinstance(status, ModelRun.Status) else status + ) if status_value not in ModelRun.Status._member_names_: raise ValueError( f"Status must be one of : `{ModelRun.Status._member_names_}`. 
Found : `{status_value}`" ) - data: Dict[str, Any] = {'status': status_value} + data: Dict[str, Any] = {"status": status_value} if error_message: - data['errorMessage'] = error_message + data["errorMessage"] = error_message if metadata: - data['metadata'] = metadata + data["metadata"] = metadata self.client.execute( """mutation setPipelineStatusPyApi($modelRunId: ID!, $data: UpdateTrainingPipelineInput!){ updateTrainingPipeline(modelRun: {id : $modelRunId}, data: $data){status} } - """, { - 'modelRunId': self.uid, - 'data': data - }, - experimental=True) + """, + {"modelRunId": self.uid, "data": data}, + experimental=True, + ) @experimental def update_config(self, config: Dict[str, Any]) -> Dict[str, Any]: """ - Updates the Model Run's training metadata config - Args: - config (dict): A dictionary of keys and values - Returns: - Model Run id and updated training metadata - """ - data: Dict[str, Any] = {'config': config} + Updates the Model Run's training metadata config + Args: + config (dict): A dictionary of keys and values + Returns: + Model Run id and updated training metadata + """ + data: Dict[str, Any] = {"config": config} res = self.client.execute( """mutation updateModelRunConfigPyApi($modelRunId: ID!, $data: UpdateModelRunConfigInput!){ updateModelRunConfig(modelRun: {id : $modelRunId}, data: $data){trainingMetadata} } - """, { - 'modelRunId': self.uid, - 'data': data - }, - experimental=True) + """, + {"modelRunId": self.uid, "data": data}, + experimental=True, + ) return res["updateModelRunConfig"] @experimental def reset_config(self) -> Dict[str, Any]: """ - Resets Model Run's training metadata config - Returns: - Model Run id and reset training metadata - """ + Resets Model Run's training metadata config + Returns: + Model Run id and reset training metadata + """ res = self.client.execute( """mutation resetModelRunConfigPyApi($modelRunId: ID!){ resetModelRunConfig(modelRun: {id : $modelRunId}){trainingMetadata} } - """, {'modelRunId': self.uid}, - experimental=True) + """, + {"modelRunId": self.uid}, + experimental=True, + ) return res["resetModelRunConfig"] @experimental def get_config(self) -> Dict[str, Any]: """ - Gets Model Run's training metadata - Returns: - training metadata as a dictionary - """ - res = self.client.execute("""query ModelRunPyApi($modelRunId: ID!){ + Gets Model Run's training metadata + Returns: + training metadata as a dictionary + """ + res = self.client.execute( + """query ModelRunPyApi($modelRunId: ID!){ modelRun(where: {id : $modelRunId}){trainingMetadata} } - """, {'modelRunId': self.uid}, - experimental=True) + """, + {"modelRunId": self.uid}, + experimental=True, + ) return res["modelRun"]["trainingMetadata"] @experimental def export_labels( - self, - download: bool = False, - timeout_seconds: int = 600 + self, download: bool = False, timeout_seconds: int = 600 ) -> Optional[Union[str, List[Dict[Any, Any]]]]: """ Experimental. To use, make sure client has enable_experimental=True. @@ -482,7 +533,8 @@ def export_labels( """ warnings.warn( "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) + DeprecationWarning, + ) sleep_time = 2 query_str = """mutation exportModelRunAnnotationsPyApi($modelRunId: ID!) 
{ exportModelRunAnnotations(data: {modelRunId: $modelRunId}) { @@ -493,8 +545,8 @@ def export_labels( while True: url = self.client.execute( - query_str, {'modelRunId': self.uid}, - experimental=True)['exportModelRunAnnotations']['downloadUrl'] + query_str, {"modelRunId": self.uid}, experimental=True + )["exportModelRunAnnotations"]["downloadUrl"] if url: if not download: @@ -508,13 +560,16 @@ def export_labels( if timeout_seconds <= 0: return None - logger.debug("ModelRun '%s' label export, waiting for server...", - self.uid) + logger.debug( + "ModelRun '%s' label export, waiting for server...", self.uid + ) time.sleep(sleep_time) - def export(self, - task_name: Optional[str] = None, - params: Optional[ModelRunExportParams] = None) -> ExportTask: + def export( + self, + task_name: Optional[str] = None, + params: Optional[ModelRunExportParams] = None, + ) -> ExportTask: """ Creates a model run export task with the given params and returns the task. @@ -536,7 +591,7 @@ def export_v2( """ task, is_streamable = self._export(task_name, params) - if (is_streamable): + if is_streamable: return ExportTask(task, True) return task @@ -550,48 +605,50 @@ def _export( create_task_query_str = ( f"mutation {mutation_name}PyApi" f"($input: ExportDataRowsInModelRunInput!)" - f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}") + f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}" + ) _params = params or ModelRunExportParams() query_params = { "input": { "taskName": task_name, - "filters": { - "modelRunId": self.uid - }, + "filters": {"modelRunId": self.uid}, "isStreamableReady": True, "params": { - "mediaTypeOverride": - _params.get('media_type_override', None), - "includeAttachments": - _params.get('attachments', False), - "includeEmbeddings": - _params.get('embeddings', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includePredictions": - _params.get('predictions', False), - "includeModelRunDetails": - _params.get('model_run_details', False), + "mediaTypeOverride": _params.get( + "media_type_override", None + ), + "includeAttachments": _params.get("attachments", False), + "includeEmbeddings": _params.get("embeddings", False), + "includeMetadata": _params.get("metadata_fields", False), + "includeDataRowDetails": _params.get( + "data_row_details", False + ), + "includePredictions": _params.get("predictions", False), + "includeModelRunDetails": _params.get( + "model_run_details", False + ), }, - "streamable": streamable + "streamable": streamable, } } - res = self.client.execute(create_task_query_str, - query_params, - error_log_key="errors") + res = self.client.execute( + create_task_query_str, query_params, error_log_key="errors" + ) res = res[mutation_name] task_id = res["taskId"] is_streamable = res["isStreamable"] return Task.get_task(self.client, task_id), is_streamable def send_to_annotate_from_model( - self, destination_project_id: str, task_queue_id: Optional[str], - batch_name: str, data_rows: Union[DataRowIds, GlobalKeys], - params: SendToAnnotateFromModelParams) -> Task: + self, + destination_project_id: str, + task_queue_id: Optional[str], + batch_name: str, + data_rows: Union[DataRowIds, GlobalKeys], + params: SendToAnnotateFromModelParams, + ) -> Task: """ Sends data rows from a model run to a project for annotation. 
@@ -625,46 +682,46 @@ def send_to_annotate_from_model( """ destination_task_queue = build_destination_task_queue_input( - task_queue_id) + task_queue_id + ) data_rows_query = self.client.build_catalog_query(data_rows) predictions_ontology_mapping = params.get( - "predictions_ontology_mapping", None) + "predictions_ontology_mapping", None + ) predictions_input = build_predictions_input( - predictions_ontology_mapping, self.uid) + predictions_ontology_mapping, self.uid + ) batch_priority = params.get("batch_priority", 5) exclude_data_rows_in_project = params.get( - "exclude_data_rows_in_project", False) + "exclude_data_rows_in_project", False + ) override_existing_annotations_rule = params.get( "override_existing_annotations_rule", - ConflictResolutionStrategy.KeepExisting) + ConflictResolutionStrategy.KeepExisting, + ) res = self.client.execute( - mutation_str, { + mutation_str, + { "input": { - "destinationProjectId": - destination_project_id, + "destinationProjectId": destination_project_id, "batchInput": { "batchName": batch_name, - "batchPriority": batch_priority + "batchPriority": batch_priority, }, - "destinationTaskQueue": - destination_task_queue, - "excludeDataRowsInProject": - exclude_data_rows_in_project, - "annotationsInput": - None, - "predictionsInput": - predictions_input, - "conflictLabelsResolutionStrategy": - override_existing_annotations_rule, + "destinationTaskQueue": destination_task_queue, + "excludeDataRowsInProject": exclude_data_rows_in_project, + "annotationsInput": None, + "predictionsInput": predictions_input, + "conflictLabelsResolutionStrategy": override_existing_annotations_rule, "searchQuery": [data_rows_query], - "sourceModelRunId": - self.uid + "sourceModelRunId": self.uid, } - })['sendToAnnotateFromMea'] + }, + )["sendToAnnotateFromMea"] - return Entity.Task.get_task(self.client, res['taskId']) + return Entity.Task.get_task(self.client, res["taskId"]) class ModelRunDataRow(DbObject): diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index 7b74acdc2..efe32611b 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -13,9 +13,12 @@ import json from pydantic import StringConstraints -FeatureSchemaId: Type[str] = Annotated[str, StringConstraints(min_length=25, - max_length=25)] -SchemaId: Type[str] = Annotated[str, StringConstraints(min_length=25, max_length=25)] +FeatureSchemaId: Type[str] = Annotated[ + str, StringConstraints(min_length=25, max_length=25) +] +SchemaId: Type[str] = Annotated[ + str, StringConstraints(min_length=25, max_length=25) +] class DeleteFeatureFromOntologyResult: @@ -23,8 +26,10 @@ class DeleteFeatureFromOntologyResult: deleted: bool def __str__(self): - return "<%s %s>" % (self.__class__.__name__.split(".")[-1], - json.dumps(self.__dict__)) + return "<%s %s>" % ( + self.__class__.__name__.split(".")[-1], + json.dumps(self.__dict__), + ) class FeatureSchema(DbObject): @@ -50,11 +55,14 @@ class Option: feature_schema_id: (str) options: (list) """ + value: Union[str, int] label: Optional[Union[str, int]] = None schema_id: Optional[str] = None feature_schema_id: Optional[FeatureSchemaId] = None - options: Union[List["Classification"], List["PromptResponseClassification"]] = field(default_factory=list) + options: Union[ + List["Classification"], List["PromptResponseClassification"] + ] = field(default_factory=list) def __post_init__(self): if self.label is None: @@ -62,17 +70,18 @@ def __post_init__(self): @classmethod def 
from_dict( - cls, - dictionary: Dict[str, - Any]) -> Dict[Union[str, int], Union[str, int]]: - return cls(value=dictionary["value"], - label=dictionary["label"], - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None), - options=[ - Classification.from_dict(o) - for o in dictionary.get("options", []) - ]) + cls, dictionary: Dict[str, Any] + ) -> Dict[Union[str, int], Union[str, int]]: + return cls( + value=dictionary["value"], + label=dictionary["label"], + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None), + options=[ + Classification.from_dict(o) + for o in dictionary.get("options", []) + ], + ) def asdict(self) -> Dict[str, Any]: return { @@ -80,20 +89,23 @@ def asdict(self) -> Dict[str, Any]: "featureSchemaId": self.feature_schema_id, "label": self.label, "value": self.value, - "options": [o.asdict(is_subclass=True) for o in self.options] + "options": [o.asdict(is_subclass=True) for o in self.options], } - def add_option(self, option: Union["Classification", "PromptResponseClassification"]) -> None: + def add_option( + self, option: Union["Classification", "PromptResponseClassification"] + ) -> None: if option.name in (o.name for o in self.options): raise InconsistentOntologyException( f"Duplicate nested classification '{option.name}' " - f"for option '{self.label}'") + f"for option '{self.label}'" + ) self.options.append(option) @dataclass class Classification: - """ + """ A classification to be added to a Project's ontology. The classification is dependent on the Classification Type. @@ -135,7 +147,7 @@ class Type(Enum): class Scope(Enum): GLOBAL = "global" INDEX = "index" - + class UIMode(Enum): HOTKEY = "hotkey" SEARCHABLE = "searchable" @@ -150,7 +162,9 @@ class UIMode(Enum): schema_id: Optional[str] = None feature_schema_id: Optional[str] = None scope: Scope = None - ui_mode: Optional[UIMode] = None # How this classification should be answered (e.g. hotkeys / autocomplete, etc) + ui_mode: Optional[UIMode] = ( + None # How this classification should be answered (e.g. hotkeys / autocomplete, etc) + ) def __post_init__(self): if self.name is None: @@ -159,7 +173,8 @@ def __post_init__(self): "for the classification schema name, which will be used when " "creating annotation payload for Model-Assisted Labeling " "Import and Label Import. “instructions” is no longer " - "supported to specify classification schema name.") + "supported to specify classification schema name." 
+ ) if self.instructions is not None: self.name = self.instructions warnings.warn(msg) @@ -171,21 +186,25 @@ def __post_init__(self): @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: - return cls(class_type=cls.Type(dictionary["type"]), - name=dictionary["name"], - instructions=dictionary["instructions"], - required=dictionary.get("required", False), - options=[Option.from_dict(o) for o in dictionary["options"]], - ui_mode=cls.UIMode(dictionary["uiMode"]) if "uiMode" in dictionary else None, - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None), - scope=cls.Scope(dictionary.get("scope", cls.Scope.GLOBAL))) + return cls( + class_type=cls.Type(dictionary["type"]), + name=dictionary["name"], + instructions=dictionary["instructions"], + required=dictionary.get("required", False), + options=[Option.from_dict(o) for o in dictionary["options"]], + ui_mode=cls.UIMode(dictionary["uiMode"]) + if "uiMode" in dictionary + else None, + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None), + scope=cls.Scope(dictionary.get("scope", cls.Scope.GLOBAL)), + ) def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: - if self.class_type in self._REQUIRES_OPTIONS \ - and len(self.options) < 1: + if self.class_type in self._REQUIRES_OPTIONS and len(self.options) < 1: raise InconsistentOntologyException( - f"Classification '{self.name}' requires options.") + f"Classification '{self.name}' requires options." + ) classification = { "type": self.class_type.value, "instructions": self.instructions, @@ -193,24 +212,32 @@ def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: "required": self.required, "options": [o.asdict() for o in self.options], "schemaNodeId": self.schema_id, - "featureSchemaId": self.feature_schema_id + "featureSchemaId": self.feature_schema_id, } - if (self.class_type == self.Type.RADIO or self.class_type == self.Type.CHECKLIST) and self.ui_mode: + if ( + self.class_type == self.Type.RADIO + or self.class_type == self.Type.CHECKLIST + ) and self.ui_mode: # added because this key does nothing for text so no point of including classification["uiMode"] = self.ui_mode.value if is_subclass: return classification - classification[ - "scope"] = self.scope.value if self.scope is not None else self.Scope.GLOBAL.value + classification["scope"] = ( + self.scope.value + if self.scope is not None + else self.Scope.GLOBAL.value + ) return classification def add_option(self, option: Option) -> None: if option.value in (o.value for o in self.options): raise InconsistentOntologyException( f"Duplicate option '{option.value}' " - f"for classification '{self.name}'.") + f"for classification '{self.name}'." 
+ ) self.options.append(option) - + + @dataclass class ResponseOption(Option): """ @@ -228,26 +255,27 @@ class ResponseOption(Option): feature_schema_id: (str) options: (list) """ - + @classmethod def from_dict( - cls, - dictionary: Dict[str, - Any]) -> Dict[Union[str, int], Union[str, int]]: - return cls(value=dictionary["value"], - label=dictionary["label"], - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None), - options=[ - PromptResponseClassification.from_dict(o) - for o in dictionary.get("options", []) - ]) + cls, dictionary: Dict[str, Any] + ) -> Dict[Union[str, int], Union[str, int]]: + return cls( + value=dictionary["value"], + label=dictionary["label"], + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None), + options=[ + PromptResponseClassification.from_dict(o) + for o in dictionary.get("options", []) + ], + ) @dataclass class PromptResponseClassification: """ - + A PromptResponseClassification to be added to a Project's ontology. The classification is dependent on the PromptResponseClassification Type. @@ -268,7 +296,7 @@ class PromptResponseClassification: >>> classification_two = PromptResponseClassification( >>> class_type = PromptResponseClassification.Type.RESPONSE_RADIO, >>> name = "Second Example") - + >>> classification_two.add_option(ResponseOption( >>> value = "Option Example")) @@ -283,7 +311,7 @@ class PromptResponseClassification: schema_id: (str) feature_schema_id: (str) """ - + def __post_init__(self): if self.name is None: msg = ( @@ -291,7 +319,8 @@ def __post_init__(self): "for the classification schema name, which will be used when " "creating annotation payload for Model-Assisted Labeling " "Import and Label Import. “instructions” is no longer " - "supported to specify classification schema name.") + "supported to specify classification schema name." 
+ ) if self.instructions is not None: self.name = self.instructions warnings.warn(msg) @@ -303,7 +332,7 @@ def __post_init__(self): class Type(Enum): PROMPT = "prompt" - RESPONSE_TEXT= "response-text" + RESPONSE_TEXT = "response-text" RESPONSE_CHECKLIST = "response-checklist" RESPONSE_RADIO = "response-radio" @@ -321,31 +350,38 @@ class Type(Enum): @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: - return cls(class_type=cls.Type(dictionary["type"]), - name=dictionary["name"], - instructions=dictionary["instructions"], - required=True, # always required - options=[ResponseOption.from_dict(o) for o in dictionary["options"]], - character_min=dictionary.get("minCharacters", None), - character_max=dictionary.get("maxCharacters", None), - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None)) + return cls( + class_type=cls.Type(dictionary["type"]), + name=dictionary["name"], + instructions=dictionary["instructions"], + required=True, # always required + options=[ + ResponseOption.from_dict(o) for o in dictionary["options"] + ], + character_min=dictionary.get("minCharacters", None), + character_max=dictionary.get("maxCharacters", None), + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None), + ) def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: - if self.class_type in self._REQUIRES_OPTIONS \ - and len(self.options) < 1: + if self.class_type in self._REQUIRES_OPTIONS and len(self.options) < 1: raise InconsistentOntologyException( - f"Response Classification '{self.name}' requires options.") + f"Response Classification '{self.name}' requires options." + ) classification = { "type": self.class_type.value, "instructions": self.instructions, "name": self.name, - "required": True, # always required + "required": True, # always required "options": [o.asdict() for o in self.options], "schemaNodeId": self.schema_id, - "featureSchemaId": self.feature_schema_id + "featureSchemaId": self.feature_schema_id, } - if (self.class_type == self.Type.PROMPT or self.class_type == self.Type.RESPONSE_TEXT): + if ( + self.class_type == self.Type.PROMPT + or self.class_type == self.Type.RESPONSE_TEXT + ): if self.character_min: classification["minCharacters"] = self.character_min if self.character_max: @@ -358,7 +394,8 @@ def add_option(self, option: ResponseOption) -> None: if option.value in (o.value for o in self.options): raise InconsistentOntologyException( f"Duplicate option '{option.value}' " - f"for response classification '{self.name}'.") + f"for response classification '{self.name}'." 
+ ) self.options.append(option) @@ -402,9 +439,9 @@ class Type(Enum): LINE = "line" NER = "named-entity" RELATIONSHIP = "edge" - MESSAGE_SINGLE_SELECTION = 'message-single-selection' - MESSAGE_MULTI_SELECTION = 'message-multi-selection' - MESSAGE_RANKING = 'message-ranking' + MESSAGE_SINGLE_SELECTION = "message-single-selection" + MESSAGE_MULTI_SELECTION = "message-multi-selection" + MESSAGE_RANKING = "message-ranking" tool: Type name: str @@ -416,16 +453,18 @@ class Type(Enum): @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: - return cls(name=dictionary['name'], - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None), - required=dictionary.get("required", False), - tool=cls.Type(dictionary["tool"]), - classifications=[ - Classification.from_dict(c) - for c in dictionary["classifications"] - ], - color=dictionary["color"]) + return cls( + name=dictionary["name"], + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None), + required=dictionary.get("required", False), + tool=cls.Type(dictionary["tool"]), + classifications=[ + Classification.from_dict(c) + for c in dictionary["classifications"] + ], + color=dictionary["color"], + ) def asdict(self) -> Dict[str, Any]: return { @@ -437,14 +476,15 @@ def asdict(self) -> Dict[str, Any]: c.asdict(is_subclass=True) for c in self.classifications ], "schemaNodeId": self.schema_id, - "featureSchemaId": self.feature_schema_id + "featureSchemaId": self.feature_schema_id, } def add_classification(self, classification: Classification) -> None: if classification.name in (c.name for c in self.classifications): raise InconsistentOntologyException( f"Duplicate nested classification '{classification.name}' " - f"for tool '{self.name}'") + f"for tool '{self.name}'" + ) self.classifications.append(classification) @@ -477,25 +517,37 @@ class Ontology(DbObject): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._tools: Optional[List[Tool]] = None - self._classifications: Optional[Union[List[Classification],List[PromptResponseClassification]]] = None + self._classifications: Optional[ + Union[List[Classification], List[PromptResponseClassification]] + ] = None def tools(self) -> List[Tool]: """Get list of tools (AKA objects) in an Ontology.""" if self._tools is None: self._tools = [ - Tool.from_dict(tool) for tool in self.normalized['tools'] + Tool.from_dict(tool) for tool in self.normalized["tools"] ] return self._tools - def classifications(self) -> List[Union[Classification, PromptResponseClassification]]: + def classifications( + self, + ) -> List[Union[Classification, PromptResponseClassification]]: """Get list of classifications in an Ontology.""" if self._classifications is None: self._classifications = [] for classification in self.normalized["classifications"]: - if "type" in classification and classification["type"] in PromptResponseClassification.Type._value2member_map_.keys(): - self._classifications.append(PromptResponseClassification.from_dict(classification)) + if ( + "type" in classification + and classification["type"] + in PromptResponseClassification.Type._value2member_map_.keys() + ): + self._classifications.append( + PromptResponseClassification.from_dict(classification) + ) else: - self._classifications.append(Classification.from_dict(classification)) + self._classifications.append( + Classification.from_dict(classification) + ) return self._classifications @@ -524,36 +576,52 
@@ class OntologyBuilder: """ + tools: List[Tool] = field(default_factory=list) - classifications: List[Union[Classification, PromptResponseClassification]] = field(default_factory=list) + classifications: List[ + Union[Classification, PromptResponseClassification] + ] = field(default_factory=list) @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: classifications = [] for c in dictionary["classifications"]: - if "type" in c and c["type"] in PromptResponseClassification.Type._value2member_map_.keys(): - classifications.append(PromptResponseClassification.from_dict(c)) + if ( + "type" in c + and c["type"] + in PromptResponseClassification.Type._value2member_map_.keys() + ): + classifications.append( + PromptResponseClassification.from_dict(c) + ) else: classifications.append(Classification.from_dict(c)) - return cls(tools=[Tool.from_dict(t) for t in dictionary["tools"]], - classifications=classifications) + return cls( + tools=[Tool.from_dict(t) for t in dictionary["tools"]], + classifications=classifications, + ) def asdict(self) -> Dict[str, Any]: self._update_colors() classifications = [] prompts = 0 for c in self.classifications: - if hasattr(c, "class_type") and c.class_type in PromptResponseClassification.Type: + if ( + hasattr(c, "class_type") + and c.class_type in PromptResponseClassification.Type + ): if c.class_type == PromptResponseClassification.Type.PROMPT: prompts += 1 if prompts > 1: - raise ValueError("Only one prompt is allowed per ontology") + raise ValueError( + "Only one prompt is allowed per ontology" + ) classifications.append(PromptResponseClassification.asdict(c)) else: classifications.append(Classification.asdict(c)) return { "tools": [t.asdict() for t in self.tools], - "classifications": classifications + "classifications": classifications, } def _update_colors(self): @@ -562,9 +630,10 @@ def _update_colors(self): for index in range(num_tools): hsv_color = (index * 1 / num_tools, 1, 1) rgb_color = tuple( - int(255 * x) for x in colorsys.hsv_to_rgb(*hsv_color)) + int(255 * x) for x in colorsys.hsv_to_rgb(*hsv_color) + ) if self.tools[index].color is None: - self.tools[index].color = '#%02x%02x%02x' % rgb_color + self.tools[index].color = "#%02x%02x%02x" % rgb_color @classmethod def from_project(cls, project: "project.Project") -> "OntologyBuilder": @@ -578,11 +647,16 @@ def from_ontology(cls, ontology: Ontology) -> "OntologyBuilder": def add_tool(self, tool: Tool) -> None: if tool.name in (t.name for t in self.tools): raise InconsistentOntologyException( - f"Duplicate tool name '{tool.name}'. ") + f"Duplicate tool name '{tool.name}'. " + ) self.tools.append(tool) - def add_classification(self, classification: Union[Classification, PromptResponseClassification]) -> None: + def add_classification( + self, + classification: Union[Classification, PromptResponseClassification], + ) -> None: if classification.name in (c.name for c in self.classifications): raise InconsistentOntologyException( - f"Duplicate classification name '{classification.name}'. ") + f"Duplicate classification name '{classification.name}'. 
" + ) self.classifications.append(classification) diff --git a/libs/labelbox/src/labelbox/schema/ontology_kind.py b/libs/labelbox/src/labelbox/schema/ontology_kind.py index 7dd3311cb..3171b811e 100644 --- a/libs/labelbox/src/labelbox/schema/ontology_kind.py +++ b/libs/labelbox/src/labelbox/schema/ontology_kind.py @@ -8,6 +8,7 @@ class OntologyKind(Enum): """ OntologyKind is an enum that represents the different types of ontologies """ + ModelEvaluation = "MODEL_EVALUATION" ResponseCreation = "RESPONSE_CREATION" Missing = None @@ -18,27 +19,31 @@ def is_supported(cls, value): @classmethod def get_ontology_kind_validation_error(cls, ontology_kind): - return TypeError(f"{ontology_kind}: is not a valid ontology kind. Use" - f" any of {OntologyKind.__members__.items()}" - " from OntologyKind.") + return TypeError( + f"{ontology_kind}: is not a valid ontology kind. Use" + f" any of {OntologyKind.__members__.items()}" + " from OntologyKind." + ) @staticmethod def evaluate_ontology_kind_with_media_type( - ontology_kind, - media_type: Optional[MediaType]) -> Union[MediaType, None]: - + ontology_kind, media_type: Optional[MediaType] + ) -> Union[MediaType, None]: ontology_to_media = { - OntologyKind.ModelEvaluation: - (MediaType.Conversational, - "For chat evaluation, media_type must be Conversational."), - OntologyKind.ResponseCreation: - (MediaType.Text, - "For response creation, media_type must be Text.") + OntologyKind.ModelEvaluation: ( + MediaType.Conversational, + "For chat evaluation, media_type must be Conversational.", + ), + OntologyKind.ResponseCreation: ( + MediaType.Text, + "For response creation, media_type must be Text.", + ), } if ontology_kind in ontology_to_media: expected_media_type, error_message = ontology_to_media[ - ontology_kind] + ontology_kind + ] if media_type is None or media_type == expected_media_type: media_type = expected_media_type @@ -59,10 +64,10 @@ def is_supported(cls, value): return isinstance(value, cls) @classmethod - def _missing_(cls, value) -> 'EditorTaskType': + def _missing_(cls, value) -> "EditorTaskType": """Handle missing null new task types - Handle upper case names for compatibility with - the GraphQL""" + Handle upper case names for compatibility with + the GraphQL""" if value is None: return cls.Missing @@ -75,34 +80,45 @@ def _missing_(cls, value) -> 'EditorTaskType': class EditorTaskTypeMapper: - @staticmethod - def to_editor_task_type(ontology_kind: OntologyKind, - media_type: MediaType) -> EditorTaskType: - if ontology_kind and OntologyKind.is_supported( - ontology_kind) and media_type and MediaType.is_supported( - media_type): + def to_editor_task_type( + ontology_kind: OntologyKind, media_type: MediaType + ) -> EditorTaskType: + if ( + ontology_kind + and OntologyKind.is_supported(ontology_kind) + and media_type + and MediaType.is_supported(media_type) + ): editor_task_type = EditorTaskTypeMapper.map_to_editor_task_type( - ontology_kind, media_type) + ontology_kind, media_type + ) else: editor_task_type = EditorTaskType.Missing return editor_task_type @staticmethod - def map_to_editor_task_type(onotology_kind: OntologyKind, - media_type: MediaType) -> EditorTaskType: - if onotology_kind == OntologyKind.ModelEvaluation and media_type == MediaType.Conversational: + def map_to_editor_task_type( + onotology_kind: OntologyKind, media_type: MediaType + ) -> EditorTaskType: + if ( + onotology_kind == OntologyKind.ModelEvaluation + and media_type == MediaType.Conversational + ): return EditorTaskType.ModelChatEvaluation - elif onotology_kind == 
OntologyKind.ResponseCreation and media_type == MediaType.Text: + elif ( + onotology_kind == OntologyKind.ResponseCreation + and media_type == MediaType.Text + ): return EditorTaskType.ResponseCreation else: return EditorTaskType.Missing class UploadType(Enum): - Auto = 'AUTO', - Manual = 'MANUAL', + Auto = ("AUTO",) + Manual = ("MANUAL",) Missing = None @classmethod @@ -110,7 +126,7 @@ def is_supported(cls, value): return isinstance(value, cls) @classmethod - def _missing_(cls, value: object) -> 'UploadType': + def _missing_(cls, value: object) -> "UploadType": if value is None: return cls.Missing diff --git a/libs/labelbox/src/labelbox/schema/organization.py b/libs/labelbox/src/labelbox/schema/organization.py index 3a5e23efc..71e715f11 100644 --- a/libs/labelbox/src/labelbox/schema/organization.py +++ b/libs/labelbox/src/labelbox/schema/organization.py @@ -9,11 +9,18 @@ from labelbox.schema.resource_tag import ResourceTag if TYPE_CHECKING: - from labelbox import Role, User, ProjectRole, Invite, InviteLimit, IAMIntegration + from labelbox import ( + Role, + User, + ProjectRole, + Invite, + InviteLimit, + IAMIntegration, + ) class Organization(DbObject): - """ An Organization is a group of Users. + """An Organization is a group of Users. It is associated with data created by Users within that Organization. Typically all Users within an Organization have access to data created by any User in the same Organization. @@ -47,10 +54,11 @@ def __init__(self, *args, **kwargs): resource_tags = Relationship.ToMany("ResourceTags", False) def invite_user( - self, - email: str, - role: "Role", - project_roles: Optional[List["ProjectRole"]] = None) -> "Invite": + self, + email: str, + role: "Role", + project_roles: Optional[List["ProjectRole"]] = None, + ) -> "Invite": """ Invite a new member to the org. 
This will send the user an email invite @@ -76,30 +84,40 @@ def invite_user( data_param = "data" query_str = """mutation createInvitesPyApi($%s: [CreateInviteInput!]){ createInvites(data: $%s){ invite { id createdAt organizationRoleName inviteeEmail inviter { %s } }}}""" % ( - data_param, data_param, query.results_query_part(Entity.User)) - - projects = [{ - "projectId": project_role.project.uid, - "projectRoleId": project_role.role.uid - } for project_role in project_roles or []] + data_param, + data_param, + query.results_query_part(Entity.User), + ) + + projects = [ + { + "projectId": project_role.project.uid, + "projectRoleId": project_role.role.uid, + } + for project_role in project_roles or [] + ] res = self.client.execute( - query_str, { - data_param: [{ - "inviterId": self.client.get_user().uid, - "inviteeEmail": email, - "organizationId": self.uid, - "organizationRoleId": role.uid, - "projects": projects - }] - }) - invite_response = res['createInvites'][0]['invite'] + query_str, + { + data_param: [ + { + "inviterId": self.client.get_user().uid, + "inviteeEmail": email, + "organizationId": self.uid, + "organizationRoleId": role.uid, + "projects": projects, + } + ] + }, + ) + invite_response = res["createInvites"][0]["invite"] if not invite_response: raise LabelboxError(f"Unable to send invite for email {email}") return Entity.Invite(self.client, invite_response) def invite_limit(self) -> InviteLimit: - """ Retrieve invite limits for the org + """Retrieve invite limits for the org This already accounts for users currently in the org Meaining that `used = users + invites, remaining = limit - (users + invites)` @@ -111,10 +129,13 @@ def invite_limit(self) -> InviteLimit: res = self.client.execute( """query InvitesLimitPyApi($%s: ID!) { invitesLimit(where: {id: $%s}) { used limit remaining } - }""" % (org_id_param, org_id_param), {org_id_param: self.uid}) - return InviteLimit(**{ - utils.snake_case(k): v for k, v in res['invitesLimit'].items() - }) + }""" + % (org_id_param, org_id_param), + {org_id_param: self.uid}, + ) + return InviteLimit( + **{utils.snake_case(k): v for k, v in res["invitesLimit"].items()} + ) def remove_user(self, user: "User") -> None: """ @@ -128,7 +149,10 @@ def remove_user(self, user: "User") -> None: self.client.execute( """mutation DeleteMemberPyApi($%s: ID!) { updateUser(where: {id: $%s}, data: {deleted: true}) { id deleted } - }""" % (user_id_param, user_id_param), {user_id_param: user.uid}) + }""" + % (user_id_param, user_id_param), + {user_id_param: user.uid}, + ) def create_resource_tag(self, tag: Dict[str, str]) -> ResourceTag: """ @@ -145,30 +169,38 @@ def create_resource_tag(self, tag: Dict[str, str]) -> ResourceTag: query_str = """mutation CreateResourceTagPyApi($text:String!,$color:String!) { createResourceTag(input:{text:$%s,color:$%s}) {%s}} - """ % (tag_text_param, tag_color_param, - query.results_query_part(ResourceTag)) + """ % ( + tag_text_param, + tag_color_param, + query.results_query_part(ResourceTag), + ) params = { tag_text_param: tag.get("text", None), - tag_color_param: tag.get("color", None) + tag_color_param: tag.get("color", None), } if not all(params.values()): raise ValueError( - f"tag must contain 'text' and 'color' keys. received: {tag}") + f"tag must contain 'text' and 'color' keys. 
received: {tag}" + ) res = self.client.execute(query_str, params) - return ResourceTag(self.client, res['createResourceTag']) + return ResourceTag(self.client, res["createResourceTag"]) def get_resource_tags(self) -> List[ResourceTag]: """ Returns all resource tags for an organization """ - query_str = """query GetOrganizationResourceTagsPyApi{organization{resourceTag{%s}}}""" % ( - query.results_query_part(ResourceTag)) + query_str = ( + """query GetOrganizationResourceTagsPyApi{organization{resourceTag{%s}}}""" + % (query.results_query_part(ResourceTag)) + ) return [ - ResourceTag(self.client, tag) for tag in self.client.execute( - query_str)['organization']['resourceTag'] + ResourceTag(self.client, tag) + for tag in self.client.execute(query_str)["organization"][ + "resourceTag" + ] ] def get_iam_integrations(self) -> List["IAMIntegration"]: @@ -184,10 +216,12 @@ def get_iam_integrations(self) -> List["IAMIntegration"]: ... on GcpIamIntegrationSettings {serviceAccountEmailId readBucket} } - } } """ % query.results_query_part(Entity.IAMIntegration)) + } } """ + % query.results_query_part(Entity.IAMIntegration) + ) return [ Entity.IAMIntegration(self.client, integration_data) - for integration_data in res['iamIntegrations'] + for integration_data in res["iamIntegrations"] ] def get_default_iam_integration(self) -> Optional["IAMIntegration"]: @@ -197,12 +231,14 @@ def get_default_iam_integration(self) -> Optional["IAMIntegration"]: """ integrations = self.get_iam_integrations() default_integration = [ - integration for integration in integrations + integration + for integration in integrations if integration.is_org_default ] if len(default_integration) > 1: raise ValueError( "Found more than one default signer. Please contact Labelbox to resolve" ) - return None if not len( - default_integration) else default_integration.pop() + return ( + None if not len(default_integration) else default_integration.pop() + ) diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py index a30ff856b..a45ddfa4b 100644 --- a/libs/labelbox/src/labelbox/schema/project.py +++ b/libs/labelbox/src/labelbox/schema/project.py @@ -6,19 +6,37 @@ from collections import namedtuple from datetime import datetime, timezone from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, TypeVar, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + TypeVar, + Union, + overload, +) from urllib.parse import urlparse -from labelbox.schema.labeling_service import LabelingService, LabelingServiceStatus +from labelbox.schema.labeling_service import ( + LabelingService, + LabelingServiceStatus, +) from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard import requests from labelbox import parser from labelbox import utils from labelbox.exceptions import error_message_for_unparsed_graphql_error -from labelbox.exceptions import (InvalidQueryError, LabelboxError, - ProcessingWaitTimeout, ResourceConflict, - ResourceNotFoundError) +from labelbox.exceptions import ( + InvalidQueryError, + LabelboxError, + ProcessingWaitTimeout, + ResourceConflict, + ResourceNotFoundError, +) from labelbox.orm import query from labelbox.orm.db_object import DbObject, Deletable, Updateable, experimental from labelbox.orm.model import Entity, Field, Relationship @@ -26,7 +44,11 @@ from labelbox.schema.consensus_settings import ConsensusSettings from labelbox.schema.create_batches_task 
import CreateBatchesTask from labelbox.schema.data_row import DataRow -from labelbox.schema.export_filters import ProjectExportFilters, validate_datetime, build_filters +from labelbox.schema.export_filters import ( + ProjectExportFilters, + validate_datetime, + build_filters, +) from labelbox.schema.export_params import ProjectExportParams from labelbox.schema.export_task import ExportTask from labelbox.schema.id_type import IdType @@ -39,24 +61,32 @@ from labelbox.schema.resource_tag import ResourceTag from labelbox.schema.task import Task from labelbox.schema.task_queue import TaskQueue -from labelbox.schema.ontology_kind import (EditorTaskType, OntologyKind, - UploadType) -from labelbox.schema.project_overview import ProjectOverview, ProjectOverviewDetailed +from labelbox.schema.ontology_kind import ( + EditorTaskType, + OntologyKind, + UploadType, +) +from labelbox.schema.project_overview import ( + ProjectOverview, + ProjectOverviewDetailed, +) if TYPE_CHECKING: from labelbox import BulkImportRequest DataRowPriority = int -LabelingParameterOverrideInput = Tuple[Union[DataRow, DataRowIdentifier], - DataRowPriority] +LabelingParameterOverrideInput = Tuple[ + Union[DataRow, DataRowIdentifier], DataRowPriority +] logger = logging.getLogger(__name__) MAX_SYNC_BATCH_ROW_COUNT = 1_000 def validate_labeling_parameter_overrides( - data: List[LabelingParameterOverrideInput]) -> None: + data: List[LabelingParameterOverrideInput], +) -> None: for idx, row in enumerate(data): if len(row) < 2: raise TypeError( @@ -131,11 +161,14 @@ class Project(DbObject, Updateable, Deletable): organization = Relationship.ToOne("Organization", False) labeling_frontend = Relationship.ToOne( "LabelingFrontend", - config=Relationship.Config(disconnect_supported=False)) + config=Relationship.Config(disconnect_supported=False), + ) labeling_frontend_options = Relationship.ToMany( - "LabelingFrontendOptions", False, "labeling_frontend_options") + "LabelingFrontendOptions", False, "labeling_frontend_options" + ) labeling_parameter_overrides = Relationship.ToMany( - "LabelingParameterOverride", False, "labeling_parameter_overrides") + "LabelingParameterOverride", False, "labeling_parameter_overrides" + ) webhooks = Relationship.ToMany("Webhook", False) benchmarks = Relationship.ToMany("Benchmark", False) ontology = Relationship.ToOne("Ontology", True) @@ -148,23 +181,31 @@ def is_chat_evaluation(self) -> bool: Returns: True if this project is a live chat evaluation project, False otherwise """ - return self.media_type == MediaType.Conversational and self.editor_task_type == EditorTaskType.ModelChatEvaluation + return ( + self.media_type == MediaType.Conversational + and self.editor_task_type == EditorTaskType.ModelChatEvaluation + ) def is_prompt_response(self) -> bool: """ Returns: True if this project is a prompt response project, False otherwise """ - return self.media_type == MediaType.LLMPromptResponseCreation or self.media_type == MediaType.LLMPromptCreation or self.editor_task_type == EditorTaskType.ResponseCreation + return ( + self.media_type == MediaType.LLMPromptResponseCreation + or self.media_type == MediaType.LLMPromptCreation + or self.editor_task_type == EditorTaskType.ResponseCreation + ) def is_auto_data_generation(self) -> bool: - return (self.upload_type == UploadType.Auto) # type: ignore + return self.upload_type == UploadType.Auto # type: ignore # we test not only the project ontology is None, but also a default empty ontology that we create when we attach a labeling front end in 
createLabelingFrontendOptions def is_empty_ontology(self) -> bool: ontology = self.ontology() # type: ignore - return ontology is None or (len(ontology.tools()) == 0 and - len(ontology.classifications()) == 0) + return ontology is None or ( + len(ontology.tools()) == 0 and len(ontology.classifications()) == 0 + ) def project_model_configs(self): query_str = """query ProjectModelConfigsPyApi($id: ID!) { @@ -189,7 +230,7 @@ def project_model_configs(self): ] def update(self, **kwargs): - """ Updates this project with the specified attributes + """Updates this project with the specified attributes Args: kwargs: a dictionary containing attributes to be upserted @@ -214,14 +255,16 @@ def update(self, **kwargs): if MediaType.is_supported(media_type): kwargs["media_type"] = media_type.value else: - raise TypeError(f"{media_type} is not a valid media type. Use" - f" any of {MediaType.get_supported_members()}" - " from MediaType. Example: MediaType.Image.") + raise TypeError( + f"{media_type} is not a valid media type. Use" + f" any of {MediaType.get_supported_members()}" + " from MediaType. Example: MediaType.Image." + ) return super().update(**kwargs) def members(self) -> PaginatedCollection: - """ Fetch all current members for this project + """Fetch all current members for this project Returns: A `PaginatedCollection` of `ProjectMember`s @@ -232,13 +275,18 @@ def members(self) -> PaginatedCollection: project(where: {id : $%s}) { id members(skip: %%d first: %%d){ id user { %s } role { id name } accessFrom } } }""" % (id_param, id_param, query.results_query_part(Entity.User)) - return PaginatedCollection(self.client, query_str, - {id_param: str(self.uid)}, - ["project", "members"], ProjectMember) + return PaginatedCollection( + self.client, + query_str, + {id_param: str(self.uid)}, + ["project", "members"], + ProjectMember, + ) def update_project_resource_tags( - self, resource_tag_ids: List[str]) -> List[ResourceTag]: - """ Creates project resource tags + self, resource_tag_ids: List[str] + ) -> List[ResourceTag]: + """Creates project resource tags Args: resource_tag_ids @@ -250,13 +298,18 @@ def update_project_resource_tags( query_str = """mutation UpdateProjectResourceTagsPyApi($%s:ID!,$%s:[String!]) { project(where:{id:$%s}){updateProjectResourceTags(input:{%s:$%s}){%s}}}""" % ( - project_id_param, tag_ids_param, project_id_param, tag_ids_param, - tag_ids_param, query.results_query_part(ResourceTag)) + project_id_param, + tag_ids_param, + project_id_param, + tag_ids_param, + tag_ids_param, + query.results_query_part(ResourceTag), + ) - res = self.client.execute(query_str, { - project_id_param: self.uid, - tag_ids_param: resource_tag_ids - }) + res = self.client.execute( + query_str, + {project_id_param: self.uid, tag_ids_param: resource_tag_ids}, + ) return [ ResourceTag(self.client, tag) @@ -274,13 +327,14 @@ def get_resource_tags(self) -> List[ResourceTag]: } }""" % (query.results_query_part(ResourceTag)) - results = self.client.execute( - query_str, {"projectId": self.uid})['project']['resourceTags'] + results = self.client.execute(query_str, {"projectId": self.uid})[ + "project" + ]["resourceTags"] return [ResourceTag(self.client, tag) for tag in results] def labels(self, datasets=None, order_by=None) -> PaginatedCollection: - """ Custom relationship expansion method to support limited filtering. + """Custom relationship expansion method to support limited filtering. 
Args: datasets (iterable of Dataset): Optional collection of Datasets @@ -292,14 +346,17 @@ def labels(self, datasets=None, order_by=None) -> PaginatedCollection: if datasets is not None: where = " where:{dataRow: {dataset: {id_in: [%s]}}}" % ", ".join( - '"%s"' % dataset.uid for dataset in datasets) + '"%s"' % dataset.uid for dataset in datasets + ) else: where = "" if order_by is not None: query.check_order_by_clause(Label, order_by) - order_by_str = "orderBy: %s_%s" % (order_by[0].graphql_name, - order_by[1].name.upper()) + order_by_str = "orderBy: %s_%s" % ( + order_by[0].graphql_name, + order_by[1].name.upper(), + ) else: order_by_str = "" @@ -307,17 +364,25 @@ def labels(self, datasets=None, order_by=None) -> PaginatedCollection: query_str = """query GetProjectLabelsPyApi($%s: ID!) {project (where: {id: $%s}) {labels (skip: %%d first: %%d %s %s) {%s}}}""" % ( - id_param, id_param, where, order_by_str, - query.results_query_part(Label)) + id_param, + id_param, + where, + order_by_str, + query.results_query_part(Label), + ) - return PaginatedCollection(self.client, query_str, {id_param: self.uid}, - ["project", "labels"], Label) + return PaginatedCollection( + self.client, + query_str, + {id_param: self.uid}, + ["project", "labels"], + Label, + ) def export_queued_data_rows( - self, - timeout_seconds=120, - include_metadata: bool = False) -> List[Dict[str, str]]: - """ Returns all data rows that are currently enqueued for this project. + self, timeout_seconds=120, include_metadata: bool = False + ) -> List[Dict[str, str]]: + """Returns all data rows that are currently enqueued for this project. Args: timeout_seconds (float): Max waiting time, in seconds. @@ -329,7 +394,8 @@ def export_queued_data_rows( """ warnings.warn( "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) + DeprecationWarning, + ) id_param = "projectId" metadata_param = "includeMetadataInput" query_str = """mutation GetQueuedDataRowsExportUrlPyApi($%s: ID!, $%s: Boolean!) @@ -338,10 +404,10 @@ def export_queued_data_rows( sleep_time = 2 start_time = time.time() while True: - res = self.client.execute(query_str, { - id_param: self.uid, - metadata_param: include_metadata - }) + res = self.client.execute( + query_str, + {id_param: self.uid, metadata_param: include_metadata}, + ) res = res["exportQueuedDataRows"] if res["status"] == "COMPLETE": download_url = res["downloadUrl"] @@ -359,14 +425,14 @@ def export_queued_data_rows( logger.debug( "Project '%s' queued data row export, waiting for server...", - self.uid) + self.uid, + ) time.sleep(sleep_time) - def export_labels(self, - download=False, - timeout_seconds=1800, - **kwargs) -> Optional[Union[str, List[Dict[Any, Any]]]]: - """ Calls the server-side Label exporting that generates a JSON + def export_labels( + self, download=False, timeout_seconds=1800, **kwargs + ) -> Optional[Union[str, List[Dict[Any, Any]]]]: + """Calls the server-side Label exporting that generates a JSON payload, and returns the URL to that payload. Will only generate a new URL at a max frequency of 30 min. @@ -389,7 +455,8 @@ def export_labels(self, """ warnings.warn( "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. 
To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) + DeprecationWarning, + ) def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str: """Returns a concatenated string of the dictionary's keys and values @@ -397,12 +464,14 @@ def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str: The string will be formatted as {key}: 'value' for each key. Value will be inclusive of quotations while key will not. This can be toggled with `value_with_quotes`""" - quote = "\"" if value_with_quotes else "" - return ",".join([ - f"""{c}: {quote}{dictionary.get(c)}{quote}""" - for c in dictionary - if dictionary.get(c) - ]) + quote = '"' if value_with_quotes else "" + return ",".join( + [ + f"""{c}: {quote}{dictionary.get(c)}{quote}""" + for c in dictionary + if dictionary.get(c) + ] + ) sleep_time = 2 id_param = "projectId" @@ -412,15 +481,16 @@ def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str: if "start" in kwargs or "end" in kwargs: created_at_dict = { "start": kwargs.get("start", ""), - "end": kwargs.get("end", "") + "end": kwargs.get("end", ""), } [validate_datetime(date) for date in created_at_dict.values()] filter_param_dict["labelCreatedAt"] = "{%s}" % _string_from_dict( - created_at_dict, value_with_quotes=True) + created_at_dict, value_with_quotes=True + ) if "last_activity_start" in kwargs or "last_activity_end" in kwargs: - last_activity_start = kwargs.get('last_activity_start') - last_activity_end = kwargs.get('last_activity_end') + last_activity_start = kwargs.get("last_activity_start") + last_activity_end = kwargs.get("last_activity_end") if last_activity_start: validate_datetime(str(last_activity_start)) @@ -428,15 +498,14 @@ def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str: validate_datetime(str(last_activity_end)) filter_param_dict["lastActivityAt"] = "{%s}" % _string_from_dict( - { - "start": last_activity_start, - "end": last_activity_end - }, - value_with_quotes=True) + {"start": last_activity_start, "end": last_activity_end}, + value_with_quotes=True, + ) if filter_param_dict: - filter_param = """, filters: {%s }""" % (_string_from_dict( - filter_param_dict, value_with_quotes=False)) + filter_param = """, filters: {%s }""" % ( + _string_from_dict(filter_param_dict, value_with_quotes=False) + ) query_str = """mutation GetLabelExportUrlPyApi($%s: ID!) 
{exportLabels(data:{projectId: $%s%s}) {downloadUrl createdAt shouldPoll} } @@ -448,7 +517,7 @@ def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str: res = self.client.execute(query_str, {id_param: self.uid}) res = res["exportLabels"] if not res["shouldPoll"] and res["downloadUrl"] is not None: - url = res['downloadUrl'] + url = res["downloadUrl"] if not download: return url else: @@ -460,8 +529,9 @@ def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str: if current_time - start_time > timeout_seconds: return None - logger.debug("Project '%s' label export, waiting for server...", - self.uid) + logger.debug( + "Project '%s' label export, waiting for server...", self.uid + ) time.sleep(sleep_time) def export( @@ -516,7 +586,7 @@ def export_v2( >>> task.result """ task, is_streamable = self._export(task_name, filters, params) - if (is_streamable): + if is_streamable: return ExportTask(task, True) return task @@ -527,34 +597,39 @@ def _export( params: Optional[ProjectExportParams] = None, streamable: bool = False, ) -> Tuple[Task, bool]: - _params = params or ProjectExportParams({ - "attachments": False, - "embeddings": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "interpolated_frames": False, - }) - - _filters = filters or ProjectExportFilters({ - "last_activity_at": None, - "label_created_at": None, - "data_row_ids": None, - "global_keys": None, - "batch_ids": None, - "workflow_status": None - }) + _params = params or ProjectExportParams( + { + "attachments": False, + "embeddings": False, + "metadata_fields": False, + "data_row_details": False, + "project_details": False, + "performance_details": False, + "label_details": False, + "media_type_override": None, + "interpolated_frames": False, + } + ) + + _filters = filters or ProjectExportFilters( + { + "last_activity_at": None, + "label_created_at": None, + "data_row_ids": None, + "global_keys": None, + "batch_ids": None, + "workflow_status": None, + } + ) mutation_name = "exportDataRowsInProject" create_task_query_str = ( f"mutation {mutation_name}PyApi" f"($input: ExportDataRowsInProjectInput!)" - f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}") + f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}" + ) - media_type_override = _params.get('media_type_override', None) + media_type_override = _params.get("media_type_override", None) query_params: Dict[str, Any] = { "input": { "taskName": task_name, @@ -564,28 +639,28 @@ def _export( "searchQuery": { "scope": None, "query": [], - } + }, }, "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeEmbeddings": - _params.get('embeddings', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), + "mediaTypeOverride": media_type_override.value + if media_type_override is not None + else None, + "includeAttachments": _params.get("attachments", False), + "includeEmbeddings": _params.get("embeddings", False), + 
"includeMetadata": _params.get("metadata_fields", False), + "includeDataRowDetails": _params.get( + "data_row_details", False + ), + "includeProjectDetails": _params.get( + "project_details", False + ), + "includePerformanceDetails": _params.get( + "performance_details", False + ), + "includeLabelDetails": _params.get("label_details", False), + "includeInterpolatedFrames": _params.get( + "interpolated_frames", False + ), }, "streamable": streamable, } @@ -594,16 +669,16 @@ def _export( search_query = build_filters(self.client, _filters) query_params["input"]["filters"]["searchQuery"]["query"] = search_query - res = self.client.execute(create_task_query_str, - query_params, - error_log_key="errors") + res = self.client.execute( + create_task_query_str, query_params, error_log_key="errors" + ) res = res[mutation_name] task_id = res["taskId"] is_streamable = res["isStreamable"] return Task.get_task(self.client, task_id), is_streamable def export_issues(self, status=None) -> str: - """ Calls the server-side Issues exporting that + """Calls the server-side Issues exporting that returns the URL to that payload. Args: @@ -622,19 +697,19 @@ def export_issues(self, status=None) -> str: valid_statuses = {None, "Open", "Resolved"} if status not in valid_statuses: - raise ValueError("status must be in {}. Found {}".format( - valid_statuses, status)) + raise ValueError( + "status must be in {}. Found {}".format(valid_statuses, status) + ) - res = self.client.execute(query_str, { - id_param: self.uid, - status_param: status - }) + res = self.client.execute( + query_str, {id_param: self.uid, status_param: status} + ) - res = res['project'] + res = res["project"] logger.debug("Project '%s' issues export, link generated", self.uid) - return res.get('issueExportUrl') + return res.get("issueExportUrl") def upsert_instructions(self, instructions_file: str) -> None: """ @@ -660,7 +735,8 @@ def upsert_instructions(self, instructions_file: str) -> None: if frontend.name != "Editor": logger.warning( f"This function has only been tested to work with the Editor front end. Found %s", - frontend.name) + frontend.name, + ) supported_instruction_formats = (".pdf", ".html") if not instructions_file.endswith(supported_instruction_formats): @@ -683,13 +759,13 @@ def upsert_instructions(self, instructions_file: str) -> None: } }""" - self.client.execute(query_str, { - 'projectId': self.uid, - 'instructions_url': instructions_url - }) + self.client.execute( + query_str, + {"projectId": self.uid, "instructions_url": instructions_url}, + ) def labeler_performance(self) -> PaginatedCollection: - """ Returns the labeler performances for this Project. + """Returns the labeler performances for this Project. Returns: A PaginatedCollection of LabelerPerformance objects. 
@@ -706,17 +782,25 @@ def create_labeler_performance(client, result): result["user"] = Entity.User(client, result["user"]) # python isoformat doesn't accept Z as utc timezone result["lastActivityTime"] = utils.format_iso_from_string( - result["lastActivityTime"].replace('Z', '+00:00')) - return LabelerPerformance(**{ - utils.snake_case(key): value for key, value in result.items() - }) + result["lastActivityTime"].replace("Z", "+00:00") + ) + return LabelerPerformance( + **{ + utils.snake_case(key): value + for key, value in result.items() + } + ) - return PaginatedCollection(self.client, query_str, {id_param: self.uid}, - ["project", "labelerPerformance"], - create_labeler_performance) + return PaginatedCollection( + self.client, + query_str, + {id_param: self.uid}, + ["project", "labelerPerformance"], + create_labeler_performance, + ) def review_metrics(self, net_score) -> int: - """ Returns this Project's review metrics. + """Returns this Project's review metrics. Args: net_score (None or Review.NetScore): Indicates desired metric. @@ -726,7 +810,8 @@ def review_metrics(self, net_score) -> int: if net_score not in (None,) + tuple(Entity.Review.NetScore): raise InvalidQueryError( "Review metrics net score must be either None " - "or one of Review.NetScore values") + "or one of Review.NetScore values" + ) id_param = "projectId" net_score_literal = "None" if net_score is None else net_score.name query_str = """query ProjectReviewMetricsPyApi($%s: ID!){ @@ -758,24 +843,23 @@ def connect_ontology(self, ontology) -> None: if not self.is_empty_ontology(): raise ValueError("Ontology already connected to project.") - if self.labeling_frontend( - ) is None: # Chat evaluation projects are automatically set up via the same api that creates a project - self._connect_default_labeling_front_end(ontology_as_dict={ - "tools": [], - "classifications": [] - }) + if ( + self.labeling_frontend() is None + ): # Chat evaluation projects are automatically set up via the same api that creates a project + self._connect_default_labeling_front_end( + ontology_as_dict={"tools": [], "classifications": []} + ) query_str = """mutation ConnectOntologyPyApi($projectId: ID!, $ontologyId: ID!){ project(where: {id: $projectId}) {connectOntology(ontologyId: $ontologyId) {id}}}""" - self.client.execute(query_str, { - 'ontologyId': ontology.uid, - 'projectId': self.uid - }) + self.client.execute( + query_str, {"ontologyId": ontology.uid, "projectId": self.uid} + ) timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") self.update(setup_complete=timestamp) def setup(self, labeling_frontend, labeling_frontend_options) -> None: - """ This method will associate default labeling frontend with the project and create an ontology based on labeling_frontend_options. + """This method will associate default labeling frontend with the project and create an ontology based on labeling_frontend_options. Args: labeling_frontend (LabelingFrontend): Do not use, this parameter is deprecated. We now associate the default labeling frontend with the project. 
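The connect_ontology changes above (and the setup changes around them) are what attach an ontology and the default "Editor" labeling frontend to a project. A short sketch of driving that from the OntologyBuilder, Tool, Classification, and Option dataclasses reformatted earlier in this patch; create_ontology, the BBOX tool type, MediaType.Image, and the placeholder ids are assumptions drawn from the wider SDK rather than from these hunks:

    import labelbox as lb
    from labelbox import Classification, OntologyBuilder, Option, Tool

    client = lb.Client(api_key="<API_KEY>")        # placeholder credentials
    project = client.get_project("<PROJECT_ID>")   # placeholder project id

    # A small ontology built with the dataclasses shown in the ontology.py hunks.
    builder = OntologyBuilder(
        tools=[Tool(tool=Tool.Type.BBOX, name="car")],  # BBOX assumed from the full Tool.Type enum
        classifications=[
            Classification(
                class_type=Classification.Type.RADIO,
                name="weather",
                options=[Option(value="sunny"), Option(value="rainy")],
            )
        ],
    )

    # create_ontology is assumed from the wider SDK; connect_ontology is the
    # method reformatted in the hunks above.
    ontology = client.create_ontology(
        "example-ontology", builder.asdict(), media_type=lb.MediaType.Image
    )
    project.connect_ontology(ontology)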
@@ -804,11 +888,15 @@ def setup(self, labeling_frontend, labeling_frontend_options) -> None: def _connect_default_labeling_front_end(self, ontology_as_dict: dict): labeling_frontend = self.labeling_frontend() - if labeling_frontend is None: # Chat evaluation projects are automatically set up via the same api that creates a project + if ( + labeling_frontend is None + ): # Chat evaluation projects are automatically set up via the same api that creates a project warnings.warn("Connecting default labeling editor for the project.") labeling_frontend = next( self.client.get_labeling_frontends( - where=Entity.LabelingFrontend.name == "Editor")) + where=Entity.LabelingFrontend.name == "Editor" + ) + ) self.labeling_frontend.connect(labeling_frontend) if not isinstance(ontology_as_dict, str): @@ -818,11 +906,13 @@ def _connect_default_labeling_front_end(self, ontology_as_dict: dict): LFO = Entity.LabelingFrontendOptions self.client._create( - LFO, { + LFO, + { LFO.project: self, LFO.labeling_frontend: labeling_frontend, - LFO.customization_options: labeling_frontend_options_str - }) + LFO.customization_options: labeling_frontend_options_str, + }, + ) def create_batch( self, @@ -855,7 +945,8 @@ def create_batch( if self.is_auto_data_generation(): raise ValueError( - "Cannot create batches for auto data generation projects") + "Cannot create batches for auto data generation projects" + ) dr_ids = [] if data_rows is not None: @@ -866,7 +957,8 @@ def create_batch( dr_ids.append(dr) else: raise ValueError( - "`data_rows` must be DataRow ids or DataRow objects") + "`data_rows` must be DataRow ids or DataRow objects" + ) if data_rows is not None: row_count = len(dr_ids) @@ -877,23 +969,28 @@ def create_batch( if row_count > 100_000: raise ValueError( - f"Batch exceeds max size, break into smaller batches") + f"Batch exceeds max size, break into smaller batches" + ) if not row_count: raise ValueError("You need at least one data row in a batch") self._wait_until_data_rows_are_processed( - dr_ids, global_keys, self._wait_processing_max_seconds) + dr_ids, global_keys, self._wait_processing_max_seconds + ) if consensus_settings: - consensus_settings = ConsensusSettings(**consensus_settings).model_dump( - by_alias=True) + consensus_settings = ConsensusSettings( + **consensus_settings + ).model_dump(by_alias=True) if row_count >= MAX_SYNC_BATCH_ROW_COUNT: - return self._create_batch_async(name, dr_ids, global_keys, priority, - consensus_settings) + return self._create_batch_async( + name, dr_ids, global_keys, priority, consensus_settings + ) else: - return self._create_batch_sync(name, dr_ids, global_keys, priority, - consensus_settings) + return self._create_batch_sync( + name, dr_ids, global_keys, priority, consensus_settings + ) def create_batches( self, @@ -936,16 +1033,19 @@ def create_batches( dr_ids.append(dr) else: raise ValueError( - "`data_rows` must be DataRow ids or DataRow objects") + "`data_rows` must be DataRow ids or DataRow objects" + ) self._wait_until_data_rows_are_processed( - dr_ids, global_keys, self._wait_processing_max_seconds) + dr_ids, global_keys, self._wait_processing_max_seconds + ) if consensus_settings: - consensus_settings = ConsensusSettings(**consensus_settings).model_dump( - by_alias=True) + consensus_settings = ConsensusSettings( + **consensus_settings + ).model_dump(by_alias=True) - method = 'createBatches' + method = "createBatches" mutation_str = """mutation %sPyApi($projectId: ID!, $input: CreateBatchesInput!) 
{ project(where: {id: $projectId}) { %s(input: $input) { @@ -965,12 +1065,13 @@ def create_batches( "dataRowIds": dr_ids, "globalKeys": global_keys, "priority": priority, - "consensusSettings": consensus_settings - } + "consensusSettings": consensus_settings, + }, } - tasks = self.client.execute( - mutation_str, params, experimental=True)["project"][method]["tasks"] + tasks = self.client.execute(mutation_str, params, experimental=True)[ + "project" + ][method]["tasks"] batch_ids = [task["batchUuid"] for task in tasks] task_ids = [task["taskId"] for task in tasks] @@ -981,8 +1082,8 @@ def create_batches_from_dataset( name_prefix: str, dataset_id: str, priority: int = 5, - consensus_settings: Optional[Dict[str, - Any]] = None) -> CreateBatchesTask: + consensus_settings: Optional[Dict[str, Any]] = None, + ) -> CreateBatchesTask: """ Creates batches for a project from a dataset, selecting only the data rows that are not already added to the project. When the dataset contains more than 100k data rows and multiple batches are needed, the specific batch @@ -1009,10 +1110,11 @@ def create_batches_from_dataset( raise ValueError("Project must be in batch mode") if consensus_settings: - consensus_settings = ConsensusSettings(**consensus_settings).model_dump( - by_alias=True) + consensus_settings = ConsensusSettings( + **consensus_settings + ).model_dump(by_alias=True) - method = 'createBatchesFromDataset' + method = "createBatchesFromDataset" mutation_str = """mutation %sPyApi($projectId: ID!, $input: CreateBatchesFromDatasetInput!) { project(where: {id: $projectId}) { %s(input: $input) { @@ -1031,21 +1133,23 @@ def create_batches_from_dataset( "batchNamePrefix": name_prefix, "datasetId": dataset_id, "priority": priority, - "consensusSettings": consensus_settings - } + "consensusSettings": consensus_settings, + }, } - tasks = self.client.execute( - mutation_str, params, experimental=True)["project"][method]["tasks"] + tasks = self.client.execute(mutation_str, params, experimental=True)[ + "project" + ][method]["tasks"] batch_ids = [task["batchUuid"] for task in tasks] task_ids = [task["taskId"] for task in tasks] return CreateBatchesTask(self.client, self.uid, batch_ids, task_ids) - def _create_batch_sync(self, name, dr_ids, global_keys, priority, - consensus_settings): - method = 'createBatchV2' + def _create_batch_sync( + self, name, dr_ids, global_keys, priority, consensus_settings + ): + method = "createBatchV2" query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) 
{ project(where: {id: $projectId}) { %s(input: $batchInput) { @@ -1064,28 +1168,30 @@ def _create_batch_sync(self, name, dr_ids, global_keys, priority, "dataRowIds": dr_ids, "globalKeys": global_keys, "priority": priority, - "consensusSettings": consensus_settings - } + "consensusSettings": consensus_settings, + }, } - res = self.client.execute(query_str, - params, - timeout=180.0, - experimental=True)["project"][method] - batch = res['batch'] - batch['size'] = res['batch']['size'] - return Entity.Batch(self.client, - self.uid, - batch, - failed_data_row_ids=res['failedDataRowIds']) - - def _create_batch_async(self, - name: str, - dr_ids: Optional[List[str]] = None, - global_keys: Optional[List[str]] = None, - priority: int = 5, - consensus_settings: Optional[Dict[str, - float]] = None): - method = 'createEmptyBatch' + res = self.client.execute( + query_str, params, timeout=180.0, experimental=True + )["project"][method] + batch = res["batch"] + batch["size"] = res["batch"]["size"] + return Entity.Batch( + self.client, + self.uid, + batch, + failed_data_row_ids=res["failedDataRowIds"], + ) + + def _create_batch_async( + self, + name: str, + dr_ids: Optional[List[str]] = None, + global_keys: Optional[List[str]] = None, + priority: int = 5, + consensus_settings: Optional[Dict[str, float]] = None, + ): + method = "createEmptyBatch" create_empty_batch_mutation_str = """mutation %sPyApi($projectId: ID!, $input: CreateEmptyBatchInput!) { project(where: {id: $projectId}) { %s(input: $input) { @@ -1097,19 +1203,18 @@ def _create_batch_async(self, params = { "projectId": self.uid, - "input": { - "name": name, - "consensusSettings": consensus_settings - } + "input": {"name": name, "consensusSettings": consensus_settings}, } - res = self.client.execute(create_empty_batch_mutation_str, - params, - timeout=180.0, - experimental=True)["project"][method] - batch_id = res['id'] + res = self.client.execute( + create_empty_batch_mutation_str, + params, + timeout=180.0, + experimental=True, + )["project"][method] + batch_id = res["id"] - method = 'addDataRowsToBatchAsync' + method = "addDataRowsToBatchAsync" add_data_rows_mutation_str = """mutation %sPyApi($projectId: ID!, $input: AddDataRowsToBatchInput!) { project(where: {id: $projectId}) { %s(input: $input) { @@ -1126,20 +1231,21 @@ def _create_batch_async(self, "dataRowIds": dr_ids, "globalKeys": global_keys, "priority": priority, - } + }, } - res = self.client.execute(add_data_rows_mutation_str, - params, - timeout=180.0, - experimental=True)["project"][method] + res = self.client.execute( + add_data_rows_mutation_str, params, timeout=180.0, experimental=True + )["project"][method] - task_id = res['taskId'] + task_id = res["taskId"] task = self._wait_for_task(task_id) if task.status != "COMPLETE": - raise LabelboxError(f"Batch was not created successfully: " + - json.dumps(task.errors)) + raise LabelboxError( + f"Batch was not created successfully: " + + json.dumps(task.errors) + ) return self.client.get_batch(self.uid, batch_id) @@ -1173,21 +1279,24 @@ def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode": status = "DISABLED" else: raise ValueError( - "Must provide either `BATCH` or `DATASET` as a mode") + "Must provide either `BATCH` or `DATASET` as a mode" + ) - query_str = """mutation %s($projectId: ID!, $status: TagSetStatusInput!) { + query_str = ( + """mutation %s($projectId: ID!, $status: TagSetStatusInput!) 
{ project(where: {id: $projectId}) { setTagSetStatus(input: {tagSetStatus: $status}) { tagSetStatus } } } - """ % "setTagSetStatusPyApi" + """ + % "setTagSetStatusPyApi" + ) - self.client.execute(query_str, { - 'projectId': self.uid, - 'status': status - }) + self.client.execute( + query_str, {"projectId": self.uid, "status": status} + ) return mode @@ -1202,7 +1311,7 @@ def get_label_count(self) -> int: } }""" - res = self.client.execute(query_str, {'projectId': self.uid}) + res = self.client.execute(query_str, {"projectId": self.uid}) return res["project"]["labelCount"] def get_queue_mode(self) -> "QueueMode": @@ -1221,17 +1330,22 @@ def get_queue_mode(self) -> "QueueMode": logger.warning( "Obtaining the queue_mode for a project through this method will soon" - " no longer be supported.") + " no longer be supported." + ) - query_str = """query %s($projectId: ID!) { + query_str = ( + """query %s($projectId: ID!) { project(where: {id: $projectId}) { tagSetStatus } } - """ % "GetTagSetStatusPyApi" + """ + % "GetTagSetStatusPyApi" + ) - status = self.client.execute( - query_str, {'projectId': self.uid})["project"]["tagSetStatus"] + status = self.client.execute(query_str, {"projectId": self.uid})[ + "project" + ]["tagSetStatus"] if status == "ENABLED": return QueueMode.Batch @@ -1241,7 +1355,7 @@ def get_queue_mode(self) -> "QueueMode": raise ValueError("Status not known") def add_model_config(self, model_config_id: str) -> str: - """ Adds a model config to this project. + """Adds a model config to this project. Args: model_config_id (str): ID of a model config to add to this project. @@ -1264,10 +1378,11 @@ def add_model_config(self, model_config_id: str) -> str: result = self.client.execute(query, params) except LabelboxError as e: if e.message.startswith( - "Unknown error: " + "Unknown error: " ): # unfortunate hack to handle unparsed graphql errors error_content = error_message_for_unparsed_graphql_error( - e.message) + e.message + ) else: error_content = e.message raise LabelboxError(message=error_content) from e @@ -1277,7 +1392,7 @@ def add_model_config(self, model_config_id: str) -> str: return result["createProjectModelConfig"]["projectModelConfigId"] def delete_project_model_config(self, project_model_config_id: str) -> bool: - """ Deletes the association between a model config and this project. + """Deletes the association between a model config and this project. Args: project_model_config_id (str): ID of a project model config association to delete for this project. @@ -1319,12 +1434,14 @@ def set_project_model_setup_complete(self) -> bool: result = self.client.execute(query, {"projectId": self.uid}) self.model_setup_complete = result["setProjectModelSetupComplete"][ - "modelSetupComplete"] + "modelSetupComplete" + ] return result["setProjectModelSetupComplete"]["modelSetupComplete"] def set_labeling_parameter_overrides( - self, data: List[LabelingParameterOverrideInput]) -> bool: - """ Adds labeling parameter overrides to this project. + self, data: List[LabelingParameterOverrideInput] + ) -> bool: + """Adds labeling parameter overrides to this project. 
See information on priority here: https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system @@ -1364,22 +1481,25 @@ def set_labeling_parameter_overrides( {setLabelingParameterOverrides (dataWithDataRowIdentifiers: [$dataWithDataRowIdentifiers]) {success}}} - """) + """ + ) data_rows_with_identifiers = "" for data_row, priority in data: if isinstance(data_row, DataRow): - data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.uid}\", idType: {IdType.DataRowId}}}, priority: {priority}}}," + data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.uid}", idType: {IdType.DataRowId}}}, priority: {priority}}},' elif isinstance(data_row, UniqueId) or isinstance( - data_row, GlobalKey): - data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.key}\", idType: {data_row.id_type}}}, priority: {priority}}}," + data_row, GlobalKey + ): + data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.key}", idType: {data_row.id_type}}}, priority: {priority}}},' else: raise TypeError( f"Data row identifier should be be of type DataRow or Data Row Identifier. Found {type(data_row)}." ) query_str = template.substitute( - dataWithDataRowIdentifiers=data_rows_with_identifiers) + dataWithDataRowIdentifiers=data_rows_with_identifiers + ) res = self.client.execute(query_str, {"projectId": self.uid}) return res["project"]["setLabelingParameterOverrides"]["success"] @@ -1422,8 +1542,10 @@ def update_data_row_labeling_priority( if isinstance(data_rows, list): data_rows = UniqueIds(data_rows) - warnings.warn("Using data row ids will be deprecated. Please use " - "UniqueIds or GlobalKeys instead.") + warnings.warn( + "Using data row ids will be deprecated. Please use " + "UniqueIds or GlobalKeys instead." + ) method = "createQueuePriorityUpdateTask" priority_param = "priority" @@ -1442,28 +1564,40 @@ def update_data_row_labeling_priority( } } } - """ % (method, priority_param, project_param, data_rows_param, - project_param, method, priority_param, data_rows_param) + """ % ( + method, + priority_param, + project_param, + data_rows_param, + project_param, + method, + priority_param, + data_rows_param, + ) res = self.client.execute( - query_str, { + query_str, + { priority_param: priority, project_param: self.uid, data_rows_param: { "ids": [id for id in data_rows], "idType": data_rows.id_type, }, - })["project"][method] + }, + )["project"][method] - task_id = res['taskId'] + task_id = res["taskId"] task = self._wait_for_task(task_id) if task.status != "COMPLETE": - raise LabelboxError(f"Priority was not updated successfully: " + - json.dumps(task.errors)) + raise LabelboxError( + f"Priority was not updated successfully: " + + json.dumps(task.errors) + ) return True def extend_reservations(self, queue_type) -> int: - """ Extends all the current reservations for the current user on the given + """Extends all the current reservations for the current user on the given queue type. 
Args: queue_type (str): Either "LabelingQueue" or "ReviewQueue" @@ -1476,12 +1610,15 @@ def extend_reservations(self, queue_type) -> int: id_param = "projectId" query_str = """mutation ExtendReservationsPyApi($%s: ID!){ extendReservations(projectId:$%s queueType:%s)}""" % ( - id_param, id_param, queue_type) + id_param, + id_param, + queue_type, + ) res = self.client.execute(query_str, {id_param: self.uid}) return res["extendReservations"] def enable_model_assisted_labeling(self, toggle: bool = True) -> bool: - """ Turns model assisted labeling either on or off based on input + """Turns model assisted labeling either on or off based on input Args: toggle (bool): True or False boolean @@ -1503,10 +1640,11 @@ def enable_model_assisted_labeling(self, toggle: bool = True) -> bool: res = self.client.execute(query_str, params) return res["project"]["showPredictionsToLabelers"][ - "showingPredictionsToLabelers"] + "showingPredictionsToLabelers" + ] def bulk_import_requests(self) -> PaginatedCollection: - """ Returns bulk import request objects which are used in model-assisted labeling. + """Returns bulk import request objects which are used in model-assisted labeling. These are returned with the oldest first, and most recent last. """ @@ -1519,15 +1657,21 @@ def bulk_import_requests(self) -> PaginatedCollection: ) { %s } - }""" % (id_param, id_param, - query.results_query_part(Entity.BulkImportRequest)) - return PaginatedCollection(self.client, query_str, - {id_param: str(self.uid)}, - ["bulkImportRequests"], - Entity.BulkImportRequest) + }""" % ( + id_param, + id_param, + query.results_query_part(Entity.BulkImportRequest), + ) + return PaginatedCollection( + self.client, + query_str, + {id_param: str(self.uid)}, + ["bulkImportRequests"], + Entity.BulkImportRequest, + ) def batches(self) -> PaginatedCollection: - """ Fetch all batches that belong to this project + """Fetch all batches that belong to this project Returns: A `PaginatedCollection` of `Batch`es @@ -1539,13 +1683,16 @@ def batches(self) -> PaginatedCollection: """ % (id_param, id_param, query.results_query_part(Entity.Batch)) return PaginatedCollection( self.client, - query_str, {id_param: self.uid}, ['project', 'batches', 'nodes'], + query_str, + {id_param: self.uid}, + ["project", "batches", "nodes"], lambda client, res: Entity.Batch(client, self.uid, res), - cursor_path=['project', 'batches', 'pageInfo', 'endCursor'], - experimental=True) + cursor_path=["project", "batches", "pageInfo", "endCursor"], + experimental=True, + ) def task_queues(self) -> List[TaskQueue]: - """ Fetch all task queues that belong to this project + """Fetch all task queues that belong to this project Returns: A `List` of `TaskQueue`s @@ -1560,9 +1707,8 @@ def task_queues(self) -> List[TaskQueue]: """ % (query.results_query_part(Entity.TaskQueue)) task_queue_values = self.client.execute( - query_str, {"projectId": self.uid}, - timeout=180.0, - experimental=True)["project"]["taskQueues"] + query_str, {"projectId": self.uid}, timeout=180.0, experimental=True + )["project"]["taskQueues"] return [ Entity.TaskQueue(self.client, field_values) @@ -1570,13 +1716,15 @@ def task_queues(self) -> List[TaskQueue]: ] @overload - def move_data_rows_to_task_queue(self, data_row_ids: DataRowIdentifiers, - task_queue_id: str): + def move_data_rows_to_task_queue( + self, data_row_ids: DataRowIdentifiers, task_queue_id: str + ): pass @overload - def move_data_rows_to_task_queue(self, data_row_ids: List[str], - task_queue_id: str): + def move_data_rows_to_task_queue( + self, 
data_row_ids: List[str], task_queue_id: str + ): pass def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str): @@ -1595,11 +1743,14 @@ def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str): """ if isinstance(data_row_ids, list): data_row_ids = UniqueIds(data_row_ids) - warnings.warn("Using data row ids will be deprecated. Please use " - "UniqueIds or GlobalKeys instead.") + warnings.warn( + "Using data row ids will be deprecated. Please use " + "UniqueIds or GlobalKeys instead." + ) method = "createBulkAddRowsToQueueTask" - query_str = """mutation AddDataRowsToTaskQueueAsyncPyApi( + query_str = ( + """mutation AddDataRowsToTaskQueueAsyncPyApi( $projectId: ID! $queueId: ID $dataRowIdentifiers: AddRowsToTaskQueueViaDataRowIdentifiersInput! @@ -1612,10 +1763,13 @@ def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str): } } } - """ % method + """ + % method + ) task_id = self.client.execute( - query_str, { + query_str, + { "projectId": self.uid, "queueId": task_queue_id, "dataRowIdentifiers": { @@ -1624,12 +1778,15 @@ def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str): }, }, timeout=180.0, - experimental=True)["project"][method]["taskId"] + experimental=True, + )["project"][method]["taskId"] task = self._wait_for_task(task_id) if task.status != "COMPLETE": - raise LabelboxError(f"Data rows were not moved successfully: " + - json.dumps(task.errors)) + raise LabelboxError( + f"Data rows were not moved successfully: " + + json.dumps(task.errors) + ) def _wait_for_task(self, task_id: str) -> Task: task = Task.get_task(self.client, task_id) @@ -1638,11 +1795,12 @@ def _wait_for_task(self, task_id: str) -> Task: return task def upload_annotations( - self, - name: str, - annotations: Union[str, Path, Iterable[Dict]], - validate: bool = False) -> 'BulkImportRequest': # type: ignore - """ Uploads annotations to a new Editor project. + self, + name: str, + annotations: Union[str, Path, Iterable[Dict]], + validate: bool = False, + ) -> "BulkImportRequest": # type: ignore + """Uploads annotations to a new Editor project. Args: name (str): name of the BulkImportRequest job @@ -1660,7 +1818,7 @@ def upload_annotations( if isinstance(annotations, str) or isinstance(annotations, Path): def _is_url_valid(url: Union[str, Path]) -> bool: - """ Verifies that the given string is a valid url. + """Verifies that the given string is a valid url. 
Args: url: string to be checked @@ -1679,12 +1837,13 @@ def _is_url_valid(url: Union[str, Path]) -> bool: project_id=self.uid, name=name, url=str(annotations), - validate=validate) + validate=validate, + ) else: path = Path(annotations) if not path.exists(): raise FileNotFoundError( - f'{annotations} is not a valid url nor existing local file' + f"{annotations} is not a valid url nor existing local file" ) return Entity.BulkImportRequest.create_from_local_file( client=self.client, @@ -1699,64 +1858,79 @@ def _is_url_valid(url: Union[str, Path]) -> bool: project_id=self.uid, name=name, predictions=annotations, # type: ignore - validate=validate) + validate=validate, + ) else: raise ValueError( - f'Invalid annotations given of type: {type(annotations)}') + f"Invalid annotations given of type: {type(annotations)}" + ) def _wait_until_data_rows_are_processed( - self, - data_row_ids: Optional[List[str]] = None, - global_keys: Optional[List[str]] = None, - wait_processing_max_seconds: int = _wait_processing_max_seconds, - sleep_interval=30): - """ Wait until all the specified data rows are processed""" + self, + data_row_ids: Optional[List[str]] = None, + global_keys: Optional[List[str]] = None, + wait_processing_max_seconds: int = _wait_processing_max_seconds, + sleep_interval=30, + ): + """Wait until all the specified data rows are processed""" start_time = datetime.now() max_data_rows_per_poll = 100_000 if data_row_ids is not None: for i in range(0, len(data_row_ids), max_data_rows_per_poll): - chunk = data_row_ids[i:i + max_data_rows_per_poll] + chunk = data_row_ids[i : i + max_data_rows_per_poll] self._poll_data_row_processing_status( - chunk, [], start_time, wait_processing_max_seconds, - sleep_interval) + chunk, + [], + start_time, + wait_processing_max_seconds, + sleep_interval, + ) if global_keys is not None: for i in range(0, len(global_keys), max_data_rows_per_poll): - chunk = global_keys[i:i + max_data_rows_per_poll] + chunk = global_keys[i : i + max_data_rows_per_poll] self._poll_data_row_processing_status( - [], chunk, start_time, wait_processing_max_seconds, - sleep_interval) + [], + chunk, + start_time, + wait_processing_max_seconds, + sleep_interval, + ) def _poll_data_row_processing_status( - self, - data_row_ids: List[str], - global_keys: List[str], - start_time: datetime, - wait_processing_max_seconds: int = _wait_processing_max_seconds, - sleep_interval=30): - + self, + data_row_ids: List[str], + global_keys: List[str], + start_time: datetime, + wait_processing_max_seconds: int = _wait_processing_max_seconds, + sleep_interval=30, + ): while True: - if (datetime.now() - - start_time).total_seconds() >= wait_processing_max_seconds: + if ( + datetime.now() - start_time + ).total_seconds() >= wait_processing_max_seconds: raise ProcessingWaitTimeout( """Maximum wait time exceeded while waiting for data rows to be processed. - Try creating a batch a bit later""") + Try creating a batch a bit later""" + ) all_good = self.__check_data_rows_have_been_processed( - data_row_ids, global_keys) + data_row_ids, global_keys + ) if all_good: return logger.debug( - 'Some of the data rows are still being processed, waiting...') + "Some of the data rows are still being processed, waiting..." 
+ ) time.sleep(sleep_interval) def __check_data_rows_have_been_processed( - self, - data_row_ids: Optional[List[str]] = None, - global_keys: Optional[List[str]] = None): - + self, + data_row_ids: Optional[List[str]] = None, + global_keys: Optional[List[str]] = None, + ): if data_row_ids is not None and len(data_row_ids) > 0: param_name = "dataRowIds" params = {param_name: data_row_ids} @@ -1773,11 +1947,12 @@ def __check_data_rows_have_been_processed( response = self.client.execute(query_str, params) return response["queryAllDataRowsHaveBeenProcessed"][ - "allDataRowsHaveBeenProcessed"] + "allDataRowsHaveBeenProcessed" + ] def get_overview( - self, - details=False) -> Union[ProjectOverview, ProjectOverviewDetailed]: + self, details=False + ) -> Union[ProjectOverview, ProjectOverviewDetailed]: """Return the overview of a project. This method returns the number of data rows per task queue and issues of a project, @@ -1816,8 +1991,9 @@ def get_overview( """ # Must use experimental to access "issues" - result = self.client.execute(query, {"projectId": self.uid}, - experimental=True)["project"] + result = self.client.execute( + query, {"projectId": self.uid}, experimental=True + )["project"] # Reformat category names overview = { @@ -1838,16 +2014,14 @@ def get_overview( # Build dictionary for queue details for review and rework queues for category in ["rework", "review"]: queues = [ - { - tq["name"]: tq.get("dataRowCount") - } + {tq["name"]: tq.get("dataRowCount")} for tq in result.get("taskQueues") if tq.get("queueType") == f"MANUAL_{category.upper()}_QUEUE" ] overview[f"in_{category}"] = { "data": queues, - "total": overview[f"in_{category}"] + "total": overview[f"in_{category}"], } return ProjectOverviewDetailed(**overview) @@ -1897,7 +2071,7 @@ def get_labeling_service_dashboard(self) -> LabelingServiceDashboard: """Get the labeling service for this project. Returns: - LabelingServiceDashboard: The labeling service for this project. + LabelingServiceDashboard: The labeling service for this project. Attributes of the dashboard include: id (str): The project id. @@ -1927,12 +2101,13 @@ class ProjectMember(DbObject): class LabelingParameterOverride(DbObject): - """ Customizes the order of assets in the label queue. + """Customizes the order of assets in the label queue. Attributes: priority (int): A prioritization score. number_of_labels (int): Number of times an asset should be labeled. """ + priority = Field.Int("priority") number_of_labels = Field.Int("number_of_labels") @@ -1940,8 +2115,10 @@ class LabelingParameterOverride(DbObject): LabelerPerformance = namedtuple( - "LabelerPerformance", "user count seconds_per_label, total_time_labeling " - "consensus average_benchmark_agreement last_activity_time") + "LabelerPerformance", + "user count seconds_per_label, total_time_labeling " + "consensus average_benchmark_agreement last_activity_time", +) LabelerPerformance.__doc__ = ( - "Named tuple containing info about a labeler's performance.") - + "Named tuple containing info about a labeler's performance." 
+) diff --git a/libs/labelbox/src/labelbox/schema/project_model_config.py b/libs/labelbox/src/labelbox/schema/project_model_config.py index 9cf6dcbfa..9b6d8a0bb 100644 --- a/libs/labelbox/src/labelbox/schema/project_model_config.py +++ b/libs/labelbox/src/labelbox/schema/project_model_config.py @@ -1,12 +1,15 @@ from labelbox.orm.db_object import DbObject from labelbox.orm.model import Field, Relationship -from labelbox.exceptions import LabelboxError, error_message_for_unparsed_graphql_error +from labelbox.exceptions import ( + LabelboxError, + error_message_for_unparsed_graphql_error, +) class ProjectModelConfig(DbObject): - """ A ProjectModelConfig represents an association between a project and a single model config. + """A ProjectModelConfig represents an association between a project and a single model config. - Attributes: + Attributes: project_id (str): ID of project to associate model_config_id (str): ID of the model configuration model_config (ModelConfig): Configuration for model @@ -17,7 +20,7 @@ class ProjectModelConfig(DbObject): model_config = Relationship.ToOne("ModelConfig", False, "model_config") def delete(self) -> bool: - """ Deletes this association between a model config and this project. + """Deletes this association between a model config and this project. Returns: bool, indicates if the operation was a success. @@ -36,10 +39,11 @@ def delete(self) -> bool: result = self.client.execute(query, params) except LabelboxError as e: if e.message.startswith( - "Unknown error: " + "Unknown error: " ): # unfortunate hack to handle unparsed graphql errors error_content = error_message_for_unparsed_graphql_error( - e.message) + e.message + ) else: error_content = e.message raise LabelboxError(message=error_content) from e diff --git a/libs/labelbox/src/labelbox/schema/project_overview.py b/libs/labelbox/src/labelbox/schema/project_overview.py index 9f6c31e02..cee195c10 100644 --- a/libs/labelbox/src/labelbox/schema/project_overview.py +++ b/libs/labelbox/src/labelbox/schema/project_overview.py @@ -2,9 +2,10 @@ from typing_extensions import TypedDict from pydantic import BaseModel + class ProjectOverview(BaseModel): """ - Class that represents a project summary as displayed in the UI, in Annotate, + Class that represents a project summary as displayed in the UI, in Annotate, under the "Overview" tab of a particular project. All attributes represent the number of data rows in the corresponding state. @@ -19,7 +20,8 @@ class ProjectOverview(BaseModel): The `labeled` attribute represents the number of data rows that have been labeled. The `total_data_rows` attribute represents the total number of data rows in the project. """ - to_label: int + + to_label: int in_review: int in_rework: int skipped: int @@ -32,16 +34,17 @@ class ProjectOverview(BaseModel): class _QueueDetail(TypedDict): """ Class that represents the detailed information of the queues in the project overview. - The `data` attribute is a list of dictionaries where the keys are the queue names + The `data` attribute is a list of dictionaries where the keys are the queue names and the values are the number of data rows in that queue. """ + data: List[Dict[str, int]] total: int - + class ProjectOverviewDetailed(BaseModel): """ - Class that represents a project summary as displayed in the UI, in Annotate, + Class that represents a project summary as displayed in the UI, in Annotate, under the "Overview" tab of a particular project. This class adds the list of task queues for the `in_review` and `in_rework` attributes. 
@@ -62,11 +65,11 @@ class ProjectOverviewDetailed(BaseModel): The `total_data_rows` attribute represents the total number of data rows in the project. """ - to_label: int + to_label: int in_review: _QueueDetail in_rework: _QueueDetail skipped: int done: int issues: int labeled: int - total_data_rows: int \ No newline at end of file + total_data_rows: int diff --git a/libs/labelbox/src/labelbox/schema/project_resource_tag.py b/libs/labelbox/src/labelbox/schema/project_resource_tag.py index bfb024c5a..18ca94860 100644 --- a/libs/labelbox/src/labelbox/schema/project_resource_tag.py +++ b/libs/labelbox/src/labelbox/schema/project_resource_tag.py @@ -3,7 +3,7 @@ class ProjectResourceTag(DbObject, Updateable): - """ Project resource tag to associate ProjectResourceTag to Project. + """Project resource tag to associate ProjectResourceTag to Project. Attributes: resourceTagId (str) diff --git a/libs/labelbox/src/labelbox/schema/resource_tag.py b/libs/labelbox/src/labelbox/schema/resource_tag.py index b1f5d6e62..8c0559486 100644 --- a/libs/labelbox/src/labelbox/schema/resource_tag.py +++ b/libs/labelbox/src/labelbox/schema/resource_tag.py @@ -3,7 +3,7 @@ class ResourceTag(DbObject, Updateable): - """ Resource tag to label and identify your labelbox resources easier. + """Resource tag to label and identify your labelbox resources easier. Attributes: text (str) diff --git a/libs/labelbox/src/labelbox/schema/review.py b/libs/labelbox/src/labelbox/schema/review.py index a9ae6d9ae..9a6850a28 100644 --- a/libs/labelbox/src/labelbox/schema/review.py +++ b/libs/labelbox/src/labelbox/schema/review.py @@ -5,7 +5,7 @@ class Review(DbObject, Deletable, Updateable): - """ Reviewing labeled data is a collaborative quality assurance technique. + """Reviewing labeled data is a collaborative quality assurance technique. A Review object indicates the quality of the assigned Label. The aggregated review numbers can be obtained on a Project object. @@ -22,8 +22,8 @@ class Review(DbObject, Deletable, Updateable): """ class NetScore(Enum): - """ Negative, Zero, or Positive. - """ + """Negative, Zero, or Positive.""" + Negative = auto() Zero = auto() Positive = auto() diff --git a/libs/labelbox/src/labelbox/schema/role.py b/libs/labelbox/src/labelbox/schema/role.py index 90930fab9..47cd753e9 100644 --- a/libs/labelbox/src/labelbox/schema/role.py +++ b/libs/labelbox/src/labelbox/schema/role.py @@ -16,26 +16,24 @@ def get_roles(client: "Client") -> Dict[str, "Role"]: query_str = """query GetAvailableUserRolesPyApi { roles { id name } }""" res = client.execute(query_str) _ROLES = {} - for role in res['roles']: - role['name'] = format_role(role['name']) - _ROLES[role['name']] = Role(client, role) + for role in res["roles"]: + role["name"] = format_role(role["name"]) + _ROLES[role["name"]] = Role(client, role) return _ROLES def format_role(name: str): - return name.upper().replace(' ', '_') + return name.upper().replace(" ", "_") class Role(DbObject): name = Field.String("name") -class OrgRole(Role): - ... +class OrgRole(Role): ... -class UserRole(Role): - ... +class UserRole(Role): ... 
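# Illustrative usage sketch, not part of the patch: how the reformatted helpers
# in role.py above behave. Assumes an authenticated Client; "<YOUR_API_KEY>" is
# a placeholder and "Team Manager" is only an example input to format_role().
from labelbox import Client
from labelbox.schema.role import get_roles, format_role

client = Client(api_key="<YOUR_API_KEY>")
roles = get_roles(client)  # dict keyed by normalized role names, e.g. roles["NONE"]
assert format_role("Team Manager") == "TEAM_MANAGER"  # upper-cased, spaces -> "_"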
@dataclass diff --git a/libs/labelbox/src/labelbox/schema/search_filters.py b/libs/labelbox/src/labelbox/schema/search_filters.py index f2ca7beae..13b158678 100644 --- a/libs/labelbox/src/labelbox/schema/search_filters.py +++ b/libs/labelbox/src/labelbox/schema/search_filters.py @@ -24,15 +24,16 @@ class OperationTypeEnum(Enum): Supported search entity types Each type corresponds to a different filter class """ - Organization = 'organization_id' - SharedWithOrganization = 'shared_with_organizations' - Workspace = 'workspace' - Tag = 'tag' - Stage = 'stage' - WorforceRequestedDate = 'workforce_requested_at' - WorkforceStageUpdatedDate = 'workforce_stage_updated_at' - TaskCompletedCount = 'task_completed_count' - TaskRemainingCount = 'task_remaining_count' + + Organization = "organization_id" + SharedWithOrganization = "shared_with_organizations" + Workspace = "workspace" + Tag = "tag" + Stage = "stage" + WorforceRequestedDate = "workforce_requested_at" + WorkforceStageUpdatedDate = "workforce_stage_updated_at" + TaskCompletedCount = "task_completed_count" + TaskRemainingCount = "task_remaining_count" def convert_enum_to_str(enum_or_str: Union[Enum, str]) -> str: @@ -41,50 +42,58 @@ def convert_enum_to_str(enum_or_str: Union[Enum, str]) -> str: return enum_or_str -OperationType = Annotated[OperationTypeEnum, - PlainSerializer(convert_enum_to_str, return_type=str)] +OperationType = Annotated[ + OperationTypeEnum, PlainSerializer(convert_enum_to_str, return_type=str) +] -IsoDatetimeType = Annotated[datetime.datetime, - PlainSerializer(format_iso_datetime)] +IsoDatetimeType = Annotated[ + datetime.datetime, PlainSerializer(format_iso_datetime) +] class IdOperator(Enum): """ Supported operators for ids like org ids, workspace ids, etc """ - Is = 'is' + + Is = "is" class RangeOperatorWithSingleValue(Enum): """ Supported operators for dates """ - Equals = 'EQUALS' - GreaterThanOrEqual = 'GREATER_THAN_OR_EQUAL' - LessThanOrEqual = 'LESS_THAN_OR_EQUAL' + + Equals = "EQUALS" + GreaterThanOrEqual = "GREATER_THAN_OR_EQUAL" + LessThanOrEqual = "LESS_THAN_OR_EQUAL" class RangeDateTimeOperatorWithSingleValue(Enum): """ Supported operators for dates """ - GreaterThanOrEqual = 'GREATER_THAN_OR_EQUAL' - LessThanOrEqual = 'LESS_THAN_OR_EQUAL' + + GreaterThanOrEqual = "GREATER_THAN_OR_EQUAL" + LessThanOrEqual = "LESS_THAN_OR_EQUAL" class RangeOperatorWithValue(Enum): """ Supported operators for date ranges """ - Between = 'BETWEEN' + + Between = "BETWEEN" class OrganizationFilter(BaseSearchFilter): """ Filter for organization to which projects belong """ - operation: OperationType = Field(default=OperationType.Organization, - serialization_alias='type') + + operation: OperationType = Field( + default=OperationType.Organization, serialization_alias="type" + ) operator: IdOperator values: List[str] @@ -95,8 +104,8 @@ class SharedWithOrganizationFilter(BaseSearchFilter): """ operation: OperationType = Field( - default=OperationType.SharedWithOrganization, - serialization_alias='type') + default=OperationType.SharedWithOrganization, serialization_alias="type" + ) operator: IdOperator values: List[str] @@ -105,8 +114,10 @@ class WorkspaceFilter(BaseSearchFilter): """ Filter for workspace """ - operation: OperationType = Field(default=OperationType.Workspace, - serialization_alias='type') + + operation: OperationType = Field( + default=OperationType.Workspace, serialization_alias="type" + ) operator: IdOperator values: List[str] @@ -116,8 +127,10 @@ class TagFilter(BaseSearchFilter): Filter for project tags values 
are tag ids """ - operation: OperationType = Field(default=OperationType.Tag, - serialization_alias='type') + + operation: OperationType = Field( + default=OperationType.Tag, serialization_alias="type" + ) operator: IdOperator values: List[str] @@ -127,18 +140,21 @@ class ProjectStageFilter(BaseSearchFilter): Filter labelbox service / aka project stages Stages are: requested, in_progress, completed etc. as described by LabelingServiceStatus """ - operation: OperationType = Field(default=OperationType.Stage, - serialization_alias='type') + + operation: OperationType = Field( + default=OperationType.Stage, serialization_alias="type" + ) operator: IdOperator values: List[LabelingServiceStatus] - @field_validator('values', mode='before') + @field_validator("values", mode="before") def validate_values(cls, values): disallowed_values = [LabelingServiceStatus.Missing] for value in values: if value in disallowed_values: raise ValueError( - f"{value} is not a valid value for ProjectStageFilter") + f"{value} is not a valid value for ProjectStageFilter" + ) return values @@ -155,6 +171,7 @@ class DateValue(BaseSearchFilter): so for a string '2024-01-01' that is run on a computer in PST, we would convert it to '2024-01-01T08:00:00Z' while the same string in EST will get converted to '2024-01-01T05:00:00Z' """ + operator: RangeDateTimeOperatorWithSingleValue value: IsoDatetimeType @@ -168,9 +185,11 @@ class WorkforceStageUpdatedFilter(BaseSearchFilter): """ Filter for workforce stage updated date """ + operation: OperationType = Field( default=OperationType.WorkforceStageUpdatedDate, - serialization_alias='type') + serialization_alias="type", + ) value: DateValue @@ -178,8 +197,10 @@ class WorkforceRequestedDateFilter(BaseSearchFilter): """ Filter for workforce requested date """ + operation: OperationType = Field( - default=OperationType.WorforceRequestedDate, serialization_alias='type') + default=OperationType.WorforceRequestedDate, serialization_alias="type" + ) value: DateValue @@ -187,14 +208,16 @@ class DateRange(BaseSearchFilter): """ Date range for a search filter """ + min: IsoDatetimeType max: IsoDatetimeType class DateRangeValue(BaseSearchFilter): """ - Date range value for a search filter + Date range value for a search filter """ + operator: RangeOperatorWithValue value: DateRange @@ -203,8 +226,10 @@ class WorkforceRequestedDateRangeFilter(BaseSearchFilter): """ Filter for workforce requested date range """ + operation: OperationType = Field( - default=OperationType.WorforceRequestedDate, serialization_alias='type') + default=OperationType.WorforceRequestedDate, serialization_alias="type" + ) value: DateRangeValue @@ -212,9 +237,11 @@ class WorkforceStageUpdatedRangeFilter(BaseSearchFilter): """ Filter for workforce stage updated date range """ + operation: OperationType = Field( default=OperationType.WorkforceStageUpdatedDate, - serialization_alias='type') + serialization_alias="type", + ) value: DateRangeValue @@ -223,8 +250,10 @@ class TaskCompletedCountFilter(BaseSearchFilter): Filter for completed tasks count A task maps to a data row. Task completed should map to a data row in a labeling queue DONE """ - operation: OperationType = Field(default=OperationType.TaskCompletedCount, - serialization_alias='type') + + operation: OperationType = Field( + default=OperationType.TaskCompletedCount, serialization_alias="type" + ) value: IntegerValue @@ -232,27 +261,41 @@ class TaskRemainingCountFilter(BaseSearchFilter): """ Filter for remaining tasks count. 
Reverse of TaskCompletedCountFilter """ - operation: OperationType = Field(default=OperationType.TaskRemainingCount, - serialization_alias='type') + + operation: OperationType = Field( + default=OperationType.TaskRemainingCount, serialization_alias="type" + ) value: IntegerValue -SearchFilter = Union[OrganizationFilter, WorkspaceFilter, - SharedWithOrganizationFilter, TagFilter, - ProjectStageFilter, WorkforceRequestedDateFilter, - WorkforceStageUpdatedFilter, - WorkforceRequestedDateRangeFilter, - WorkforceStageUpdatedRangeFilter, TaskCompletedCountFilter, - TaskRemainingCountFilter] +SearchFilter = Union[ + OrganizationFilter, + WorkspaceFilter, + SharedWithOrganizationFilter, + TagFilter, + ProjectStageFilter, + WorkforceRequestedDateFilter, + WorkforceStageUpdatedFilter, + WorkforceRequestedDateRangeFilter, + WorkforceStageUpdatedRangeFilter, + TaskCompletedCountFilter, + TaskRemainingCountFilter, +] def _dict_to_graphql_string(d: Union[dict, list, str, int]) -> str: if isinstance(d, dict): - return "{" + ", ".join( - f'{k}: {_dict_to_graphql_string(v)}' for k, v in d.items()) + "}" + return ( + "{" + + ", ".join( + f"{k}: {_dict_to_graphql_string(v)}" for k, v in d.items() + ) + + "}" + ) elif isinstance(d, list): - return "[" + ", ".join( - _dict_to_graphql_string(item) for item in d) + "]" + return ( + "[" + ", ".join(_dict_to_graphql_string(item) for item in d) + "]" + ) else: return f'"{d}"' if isinstance(d, str) else str(d) diff --git a/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py b/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py index f3636e14d..18bd26637 100644 --- a/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py +++ b/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py @@ -2,7 +2,9 @@ from typing import Optional, Dict -from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy +from labelbox.schema.conflict_resolution_strategy import ( + ConflictResolutionStrategy, +) if sys.version_info >= (3, 8): from typing import TypedDict @@ -37,22 +39,24 @@ class SendToAnnotateFromCatalogParams(BaseModel): predictions_ontology_mapping: Optional[Dict[str, str]] = {} annotations_ontology_mapping: Optional[Dict[str, str]] = {} exclude_data_rows_in_project: Optional[bool] = False - override_existing_annotations_rule: Optional[ - ConflictResolutionStrategy] = ConflictResolutionStrategy.KeepExisting + override_existing_annotations_rule: Optional[ConflictResolutionStrategy] = ( + ConflictResolutionStrategy.KeepExisting + ) batch_priority: Optional[int] = 5 @model_validator(mode="after") def check_project_id_or_model_run_id(self): if not self.source_model_run_id and not self.source_project_id: raise ValueError( - 'Either source_project_id or source_model_id are required' + "Either source_project_id or source_model_id are required" ) if self.source_model_run_id and self.source_project_id: raise ValueError( - 'Provide only a source_project_id or source_model_id not both' - ) + "Provide only a source_project_id or source_model_id not both" + ) return self + class SendToAnnotateFromModelParams(TypedDict): """ Extra parameters for sending data rows to a project through a model run. 
@@ -73,36 +77,35 @@ class SendToAnnotateFromModelParams(TypedDict): batch_priority: Optional[int] -def build_annotations_input(project_ontology_mapping: Optional[Dict[str, str]], - source_project_id: str): +def build_annotations_input( + project_ontology_mapping: Optional[Dict[str, str]], source_project_id: str +): return { - "projectId": - source_project_id, - "featureSchemaIdsMapping": - project_ontology_mapping if project_ontology_mapping else {}, + "projectId": source_project_id, + "featureSchemaIdsMapping": project_ontology_mapping + if project_ontology_mapping + else {}, } def build_destination_task_queue_input(task_queue_id: str): - destination_task_queue = { - "type": "id", - "value": task_queue_id - } if task_queue_id else { - "type": "done" - } + destination_task_queue = ( + {"type": "id", "value": task_queue_id} + if task_queue_id + else {"type": "done"} + ) return destination_task_queue -def build_predictions_input(model_run_ontology_mapping: Optional[Dict[str, - str]], - source_model_run_id: str): +def build_predictions_input( + model_run_ontology_mapping: Optional[Dict[str, str]], + source_model_run_id: str, +): return { - "featureSchemaIdsMapping": - model_run_ontology_mapping if model_run_ontology_mapping else {}, - "modelRunId": - source_model_run_id, - "minConfidence": - 0, - "maxConfidence": - 1 + "featureSchemaIdsMapping": model_run_ontology_mapping + if model_run_ontology_mapping + else {}, + "modelRunId": source_model_run_id, + "minConfidence": 0, + "maxConfidence": 1, } diff --git a/libs/labelbox/src/labelbox/schema/serialization.py b/libs/labelbox/src/labelbox/schema/serialization.py index cfbbb04f8..ca5537fd9 100644 --- a/libs/labelbox/src/labelbox/schema/serialization.py +++ b/libs/labelbox/src/labelbox/schema/serialization.py @@ -5,8 +5,8 @@ def serialize_labels( - objects: Union[List[Dict[str, Any]], - List["Label"]]) -> List[Dict[str, Any]]: + objects: Union[List[Dict[str, Any]], List["Label"]], +) -> List[Dict[str, Any]]: """ Checks if objects are of type Labels and serializes labels for annotation import. Serialization depends the labelbox[data] package, therefore NDJsonConverter is only loaded if using `Label` objects instead of `dict` objects. """ @@ -17,6 +17,7 @@ def serialize_labels( if is_label_type: # If a Label object exists, labelbox[data] is already installed, so no error checking is needed. 
from labelbox.data.serialization import NDJsonConverter + labels = cast(List["Label"], objects) return list(NDJsonConverter.serialize(labels)) diff --git a/libs/labelbox/src/labelbox/schema/slice.py b/libs/labelbox/src/labelbox/schema/slice.py index ffd1f2768..624731024 100644 --- a/libs/labelbox/src/labelbox/schema/slice.py +++ b/libs/labelbox/src/labelbox/schema/slice.py @@ -4,7 +4,10 @@ from labelbox.orm.db_object import DbObject, experimental from labelbox.orm.model import Field from labelbox.pagination import PaginatedCollection -from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params +from labelbox.schema.export_params import ( + CatalogExportParams, + validate_catalog_export_params, +) from labelbox.schema.export_task import ExportTask from labelbox.schema.identifiable import GlobalKey, UniqueId from labelbox.schema.task import Task @@ -41,7 +44,7 @@ def __init__(self, id: str, global_key: Optional[str]): def to_hash(self): return { "id": self.id.key, - "global_key": self.global_key.key if self.global_key else None + "global_key": self.global_key.key if self.global_key else None, } @@ -81,10 +84,11 @@ def get_data_row_ids(self) -> PaginatedCollection: return PaginatedCollection( client=self.client, query=query_str, - params={'id': str(self.uid)}, - dereferencing=['getDataRowIdsBySavedQuery', 'nodes'], + params={"id": str(self.uid)}, + dereferencing=["getDataRowIdsBySavedQuery", "nodes"], obj_class=lambda _, data_row_id: data_row_id, - cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor']) + cursor_path=["getDataRowIdsBySavedQuery", "pageInfo", "endCursor"], + ) def get_data_row_identifiers(self) -> PaginatedCollection: """ @@ -116,18 +120,24 @@ def get_data_row_identifiers(self) -> PaginatedCollection: return PaginatedCollection( client=self.client, query=query_str, - params={'id': str(self.uid)}, - dereferencing=['getDataRowIdentifiersBySavedQuery', 'nodes'], + params={"id": str(self.uid)}, + dereferencing=["getDataRowIdentifiersBySavedQuery", "nodes"], obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey( - data_row_id_and_gk.get('id'), - data_row_id_and_gk.get('globalKey', None)), + data_row_id_and_gk.get("id"), + data_row_id_and_gk.get("globalKey", None), + ), cursor_path=[ - 'getDataRowIdentifiersBySavedQuery', 'pageInfo', 'endCursor' - ]) + "getDataRowIdentifiersBySavedQuery", + "pageInfo", + "endCursor", + ], + ) - def export(self, - task_name: Optional[str] = None, - params: Optional[CatalogExportParams] = None) -> ExportTask: + def export( + self, + task_name: Optional[str] = None, + params: Optional[CatalogExportParams] = None, + ) -> ExportTask: """ Creates a slice export task with the given params and returns the task. 
>>> slice = client.get_catalog_slice("SLICE_ID") @@ -155,7 +165,7 @@ def export_v2( >>> task.result """ task, is_streamable = self._export(task_name, params) - if (is_streamable): + if is_streamable: return ExportTask(task, True) return task @@ -165,73 +175,70 @@ def _export( params: Optional[CatalogExportParams] = None, streamable: bool = False, ) -> Tuple[Task, bool]: - _params = params or CatalogExportParams({ - "attachments": False, - "embeddings": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "model_run_ids": None, - "project_ids": None, - "interpolated_frames": False, - "all_projects": False, - "all_model_runs": False, - }) + _params = params or CatalogExportParams( + { + "attachments": False, + "embeddings": False, + "metadata_fields": False, + "data_row_details": False, + "project_details": False, + "performance_details": False, + "label_details": False, + "media_type_override": None, + "model_run_ids": None, + "project_ids": None, + "interpolated_frames": False, + "all_projects": False, + "all_model_runs": False, + } + ) validate_catalog_export_params(_params) mutation_name = "exportDataRowsInSlice" create_task_query_str = ( f"mutation {mutation_name}PyApi" f"($input: ExportDataRowsInSliceInput!)" - f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}") + f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}" + ) - media_type_override = _params.get('media_type_override', None) + media_type_override = _params.get("media_type_override", None) query_params = { "input": { "taskName": task_name, - "filters": { - "sliceId": self.uid - }, + "filters": {"sliceId": self.uid}, "isStreamableReady": True, "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeEmbeddings": - _params.get('embeddings', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), - "projectIds": - _params.get('project_ids', None), - "modelRunIds": - _params.get('model_run_ids', None), - "allProjects": - _params.get('all_projects', False), - "allModelRuns": - _params.get('all_model_runs', False), + "mediaTypeOverride": media_type_override.value + if media_type_override is not None + else None, + "includeAttachments": _params.get("attachments", False), + "includeEmbeddings": _params.get("embeddings", False), + "includeMetadata": _params.get("metadata_fields", False), + "includeDataRowDetails": _params.get( + "data_row_details", False + ), + "includeProjectDetails": _params.get( + "project_details", False + ), + "includePerformanceDetails": _params.get( + "performance_details", False + ), + "includeLabelDetails": _params.get("label_details", False), + "includeInterpolatedFrames": _params.get( + "interpolated_frames", False + ), + "projectIds": _params.get("project_ids", None), + "modelRunIds": _params.get("model_run_ids", None), + "allProjects": _params.get("all_projects", False), + "allModelRuns": _params.get("all_model_runs", False), }, "streamable": streamable, } } - res = 
self.client.execute(create_task_query_str, - query_params, - error_log_key="errors") + res = self.client.execute( + create_task_query_str, query_params, error_log_key="errors" + ) res = res[mutation_name] task_id = res["taskId"] is_streamable = res["isStreamable"] @@ -284,20 +291,21 @@ def get_data_row_ids(self, model_run_id: str) -> PaginatedCollection: return PaginatedCollection( client=self.client, query=ModelSlice.query_str(), - params={ - 'id': str(self.uid), - 'modelRunId': model_run_id - }, - dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'], - obj_class=lambda _, data_row_id_and_gk: data_row_id_and_gk.get('id' - ), + params={"id": str(self.uid), "modelRunId": model_run_id}, + dereferencing=["getDataRowIdentifiersBySavedModelQuery", "nodes"], + obj_class=lambda _, data_row_id_and_gk: data_row_id_and_gk.get( + "id" + ), cursor_path=[ - 'getDataRowIdentifiersBySavedModelQuery', 'pageInfo', - 'endCursor' - ]) + "getDataRowIdentifiersBySavedModelQuery", + "pageInfo", + "endCursor", + ], + ) - def get_data_row_identifiers(self, - model_run_id: str) -> PaginatedCollection: + def get_data_row_identifiers( + self, model_run_id: str + ) -> PaginatedCollection: """ Fetches all data row ids and global keys (where defined) that match this Slice @@ -310,15 +318,15 @@ def get_data_row_identifiers(self, return PaginatedCollection( client=self.client, query=ModelSlice.query_str(), - params={ - 'id': str(self.uid), - 'modelRunId': model_run_id - }, - dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'], + params={"id": str(self.uid), "modelRunId": model_run_id}, + dereferencing=["getDataRowIdentifiersBySavedModelQuery", "nodes"], obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey( - data_row_id_and_gk.get('id'), - data_row_id_and_gk.get('globalKey', None)), + data_row_id_and_gk.get("id"), + data_row_id_and_gk.get("globalKey", None), + ), cursor_path=[ - 'getDataRowIdentifiersBySavedModelQuery', 'pageInfo', - 'endCursor' - ]) + "getDataRowIdentifiersBySavedModelQuery", + "pageInfo", + "endCursor", + ], + ) diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py index 19d27c325..9d7a26e1d 100644 --- a/libs/labelbox/src/labelbox/schema/task.py +++ b/libs/labelbox/src/labelbox/schema/task.py @@ -11,7 +11,8 @@ from labelbox.pagination import PaginatedCollection from labelbox.schema.internal.datarow_upload_constants import ( - DOWNLOAD_RESULT_PAGE_SIZE,) + DOWNLOAD_RESULT_PAGE_SIZE, +) if TYPE_CHECKING: from labelbox import User @@ -25,7 +26,7 @@ def lru_cache() -> Callable[..., Callable[..., Dict[str, Any]]]: class Task(DbObject): - """ Represents a server-side process that might take a longer time to process. + """Represents a server-side process that might take a longer time to process. Allows the Task state to be updated and checked on the client side. 
Attributes: @@ -38,6 +39,7 @@ class Task(DbObject): created_by (Relationship): `ToOne` relationship to User organization (Relationship): `ToOne` relationship to Organization """ + updated_at = Field.DateTime("updated_at") created_at = Field.DateTime("created_at") name = Field.String("name") @@ -54,18 +56,21 @@ class Task(DbObject): organization = Relationship.ToOne("Organization") def __eq__(self, task): - return isinstance( - task, Task) and task.uid == self.uid and task.type == self.type + return ( + isinstance(task, Task) + and task.uid == self.uid + and task.type == self.type + ) def __hash__(self): return hash(self.uid) # Import and upsert have several instances of special casing def is_creation_task(self) -> bool: - return self.name == 'JSON Import' or self.type == 'adv-upsert-data-rows' + return self.name == "JSON Import" or self.type == "adv-upsert-data-rows" def refresh(self) -> None: - """ Refreshes Task data from the server. """ + """Refreshes Task data from the server.""" assert self._user is not None tasks = list(self._user.created_tasks(where=Task.uid == self.uid)) if len(tasks) != 1: @@ -84,24 +89,25 @@ def has_errors(self) -> bool: return bool(self.failed_data_rows) return self.status == "FAILED" - def wait_until_done(self, - timeout_seconds: float = 300.0, - check_frequency: float = 2.0) -> None: + def wait_until_done( + self, timeout_seconds: float = 300.0, check_frequency: float = 2.0 + ) -> None: self.wait_till_done(timeout_seconds, check_frequency) - def wait_till_done(self, - timeout_seconds: float = 300.0, - check_frequency: float = 2.0) -> None: - """ Waits until the task is completed. Periodically queries the server - to update the task attributes. + def wait_till_done( + self, timeout_seconds: float = 300.0, check_frequency: float = 2.0 + ) -> None: + """Waits until the task is completed. Periodically queries the server + to update the task attributes. - Args: - timeout_seconds (float): Maximum time this method can block, in seconds. Defaults to five minutes. - check_frequency (float): Frequency of queries to server to update the task attributes, in seconds. Defaults to two seconds. Minimal value is two seconds. - """ + Args: + timeout_seconds (float): Maximum time this method can block, in seconds. Defaults to five minutes. + check_frequency (float): Frequency of queries to server to update the task attributes, in seconds. Defaults to two seconds. Minimal value is two seconds. + """ if check_frequency < 2.0: raise ValueError( - "Expected check frequency to be two seconds or more") + "Expected check frequency to be two seconds or more" + ) while timeout_seconds > 0: if self.status != "IN_PROGRESS": if self.has_errors(): @@ -109,16 +115,16 @@ def wait_till_done(self, "There are errors present. Please look at `task.errors` for more details" ) return - logger.debug("Task.wait_till_done sleeping for %d seconds", - check_frequency) + logger.debug( + "Task.wait_till_done sleeping for %d seconds", check_frequency + ) time.sleep(check_frequency) timeout_seconds -= check_frequency self.refresh() @property def errors(self) -> Optional[Dict[str, Any]]: - """ Fetch the error associated with an import task. 
- """ + """Fetch the error associated with an import task.""" if self.is_creation_task(): if self.status == "FAILED": result = self._fetch_remote_json() @@ -126,10 +132,12 @@ def errors(self) -> Optional[Dict[str, Any]]: elif self.status == "COMPLETE": return self.failed_data_rows elif self.type == "export-data-rows": - return self._fetch_remote_json(remote_json_field='errors_url') - elif (self.type == "add-data-rows-to-batch" or - self.type == "send-to-task-queue" or - self.type == "send-to-annotate"): + return self._fetch_remote_json(remote_json_field="errors_url") + elif ( + self.type == "add-data-rows-to-batch" + or self.type == "send-to-task-queue" + or self.type == "send-to-annotate" + ): if self.status == "FAILED": # for these tasks, the error is embedded in the result itself return json.loads(self.result_url) @@ -137,26 +145,27 @@ def errors(self) -> Optional[Dict[str, Any]]: @property def result(self) -> Union[List[Dict[str, Any]], Dict[str, Any]]: - """ Fetch the result for an import task. - """ + """Fetch the result for an import task.""" if self.status == "FAILED": raise ValueError(f"Job failed. Errors : {self.errors}") else: result = self._fetch_remote_json() - if self.type == 'export-data-rows': + if self.type == "export-data-rows": return result - return [{ - 'id': data_row['id'], - 'external_id': data_row.get('externalId'), - 'row_data': data_row['rowData'], - 'global_key': data_row.get('globalKey'), - } for data_row in result['createdDataRows']] + return [ + { + "id": data_row["id"], + "external_id": data_row.get("externalId"), + "row_data": data_row["rowData"], + "global_key": data_row.get("globalKey"), + } + for data_row in result["createdDataRows"] + ] @property def failed_data_rows(self) -> Optional[Dict[str, Any]]: - """ Fetch data rows which failed to be created for an import task. - """ + """Fetch data rows which failed to be created for an import task.""" result = self._fetch_remote_json() if len(result.get("errors", [])) > 0: return result["errors"] @@ -165,8 +174,7 @@ def failed_data_rows(self) -> Optional[Dict[str, Any]]: @property def created_data_rows(self) -> Optional[Dict[str, Any]]: - """ Fetch data rows which successfully created for an import task. - """ + """Fetch data rows which successfully created for an import task.""" result = self._fetch_remote_json() if len(result.get("createdDataRows", [])) > 0: return result["createdDataRows"] @@ -174,23 +182,22 @@ def created_data_rows(self) -> Optional[Dict[str, Any]]: return None @lru_cache() - def _fetch_remote_json(self, - remote_json_field: Optional[str] = None - ) -> Dict[str, Any]: - """ Function for fetching and caching the result data. 
- """ + def _fetch_remote_json( + self, remote_json_field: Optional[str] = None + ) -> Dict[str, Any]: + """Function for fetching and caching the result data.""" def download_result(remote_json_field: Optional[str], format: str): - url = getattr(self, remote_json_field or 'result_url') + url = getattr(self, remote_json_field or "result_url") if url is None: return None response = requests.get(url) response.raise_for_status() - if format == 'json': + if format == "json": return response.json() - elif format == 'ndjson': + elif format == "ndjson": return parser.loads(response.text) else: raise ValueError( @@ -198,9 +205,9 @@ def download_result(remote_json_field: Optional[str], format: str): ) if self.is_creation_task(): - format = 'json' - elif self.type == 'export-data-rows': - format = 'ndjson' + format = "json" + elif self.type == "export-data-rows": + format = "ndjson" else: raise ValueError( "Task result is only supported for `JSON Import` and `export` tasks." @@ -221,7 +228,8 @@ def download_result(remote_json_field: Optional[str], format: str): def get_task(client, task_id): user: User = client.get_user() tasks: List[Task] = list( - user.created_tasks(where=Entity.Task.uid == task_id)) + user.created_tasks(where=Entity.Task.uid == task_id) + ) # Cache user in a private variable as the relationship can't be # resolved due to server-side limitations (see Task.created_by) # for more info. @@ -261,12 +269,14 @@ def errors(self) -> Optional[List[Dict[str, Any]]]: # type: ignore @property def created_data_rows( # type: ignore - self) -> Optional[List[Dict[str, Any]]]: + self, + ) -> Optional[List[Dict[str, Any]]]: return self.result @property def failed_data_rows( # type: ignore - self) -> Optional[List[Dict[str, Any]]]: + self, + ) -> Optional[List[Dict[str, Any]]]: return self.errors def _download_results_paginated(self) -> PaginatedCollection: @@ -289,23 +299,23 @@ def _download_results_paginated(self) -> PaginatedCollection: """ params = { - 'taskId': self.uid, - 'first': page_size, - 'from': from_cursor, + "taskId": self.uid, + "first": page_size, + "from": from_cursor, } return PaginatedCollection( client=self.client, query=query_str, params=params, - dereferencing=['successesfulDataRowImports', 'nodes'], + dereferencing=["successesfulDataRowImports", "nodes"], obj_class=lambda _, data_row: { - 'id': data_row.get('id'), - 'external_id': data_row.get('externalId'), - 'row_data': data_row.get('rowData'), - 'global_key': data_row.get('globalKey'), + "id": data_row.get("id"), + "external_id": data_row.get("externalId"), + "row_data": data_row.get("rowData"), + "global_key": data_row.get("globalKey"), }, - cursor_path=['successesfulDataRowImports', 'after'], + cursor_path=["successesfulDataRowImports", "after"], ) def _download_errors_paginated(self) -> PaginatedCollection: @@ -340,32 +350,33 @@ def _download_errors_paginated(self) -> PaginatedCollection: """ params = { - 'taskId': self.uid, - 'first': page_size, - 'from': from_cursor, + "taskId": self.uid, + "first": page_size, + "from": from_cursor, } def convert_errors_to_legacy_format(client, data_row): - spec = data_row.get('spec', {}) + spec = data_row.get("spec", {}) return { - 'message': - data_row.get('message'), - 'failedDataRows': [{ - 'externalId': spec.get('externalId'), - 'rowData': spec.get('rowData'), - 'globalKey': spec.get('globalKey'), - 'metadata': spec.get('metadata', []), - 'attachments': spec.get('attachments', []), - }] + "message": data_row.get("message"), + "failedDataRows": [ + { + "externalId": 
spec.get("externalId"), + "rowData": spec.get("rowData"), + "globalKey": spec.get("globalKey"), + "metadata": spec.get("metadata", []), + "attachments": spec.get("attachments", []), + } + ], } return PaginatedCollection( client=self.client, query=query_str, params=params, - dereferencing=['failedDataRowImports', 'results'], + dereferencing=["failedDataRowImports", "results"], obj_class=convert_errors_to_legacy_format, - cursor_path=['failedDataRowImports', 'after'], + cursor_path=["failedDataRowImports", "after"], ) def _results_as_list(self) -> Optional[List[Dict[str, Any]]]: diff --git a/libs/labelbox/src/labelbox/schema/user.py b/libs/labelbox/src/labelbox/schema/user.py index 430868b85..f7b3cd0d6 100644 --- a/libs/labelbox/src/labelbox/schema/user.py +++ b/libs/labelbox/src/labelbox/schema/user.py @@ -7,7 +7,7 @@ class User(DbObject): - """ A User is a registered Labelbox user (for example you) associated with + """A User is a registered Labelbox user (for example you) associated with data they create or import and an Organization they belong to. Attributes: @@ -43,7 +43,7 @@ class User(DbObject): org_role = Relationship.ToOne("OrgRole", False) def update_org_role(self, role: "Role") -> None: - """ Updated the `User`s organization role. + """Updated the `User`s organization role. See client.get_roles() to get all valid roles If you a user is converted from project level permissions to org level permissions and then convert back, their permissions will remain for each individual project @@ -58,23 +58,22 @@ def update_org_role(self, role: "Role") -> None: setOrganizationRole(data: {userId: $userId, roleId: $roleId}) { id name }} """ % (user_id_param, role_id_param) - self.client.execute(query_str, { - user_id_param: self.uid, - role_id_param: role.uid - }) + self.client.execute( + query_str, {user_id_param: self.uid, role_id_param: role.uid} + ) def remove_from_project(self, project: "Project") -> None: - """ Removes a User from a project. Only used for project based users. + """Removes a User from a project. Only used for project based users. Project based user means their org role is "NONE" Args: project (Project): Project to remove user from """ - self.upsert_project_role(project, self.client.get_roles()['NONE']) + self.upsert_project_role(project, self.client.get_roles()["NONE"]) def upsert_project_role(self, project: "Project", role: "Role") -> None: - """ Updates or replaces a User's role in a project. + """Updates or replaces a User's role in a project. Args: project (Project): The project to update the users permissions for @@ -82,21 +81,30 @@ def upsert_project_role(self, project: "Project", role: "Role") -> None: """ org_role = self.org_role() - if org_role.name.upper() != 'NONE': + if org_role.name.upper() != "NONE": raise ValueError( - "User is not project based and has access to all projects") + "User is not project based and has access to all projects" + ) project_id_param = "projectId" user_id_param = "userId" role_id_param = "roleId" query_str = """mutation SetProjectMembershipPyApi($%s: ID!, $%s: ID!, $%s: ID!) 
{ setProjectMembership(data: {%s: $userId, roleId: $%s, projectId: $%s}) {id}} - """ % (user_id_param, role_id_param, project_id_param, user_id_param, - role_id_param, project_id_param) + """ % ( + user_id_param, + role_id_param, + project_id_param, + user_id_param, + role_id_param, + project_id_param, + ) self.client.execute( - query_str, { + query_str, + { project_id_param: project.uid, user_id_param: self.uid, - role_id_param: role.uid - }) + role_id_param: role.uid, + }, + ) diff --git a/libs/labelbox/src/labelbox/schema/user_group.py b/libs/labelbox/src/labelbox/schema/user_group.py index 91cdb159c..9d506bf92 100644 --- a/libs/labelbox/src/labelbox/schema/user_group.py +++ b/libs/labelbox/src/labelbox/schema/user_group.py @@ -6,7 +6,11 @@ from labelbox.exceptions import ResourceCreationError from labelbox.schema.user import User from labelbox.schema.project import Project -from labelbox.exceptions import UnprocessableEntityError, MalformedQueryException, ResourceNotFoundError +from labelbox.exceptions import ( + UnprocessableEntityError, + MalformedQueryException, + ResourceNotFoundError, +) from labelbox.schema.queue_mode import QueueMode from labelbox.schema.ontology_kind import EditorTaskType from labelbox.schema.media_type import MediaType @@ -28,6 +32,7 @@ class UserGroupColor(Enum): YELLOW (str): Hex color code for yellow (#E7BF00). GRAY (str): Hex color code for gray (#B8C4D3). """ + BLUE = "9EC5FF" PURPLE = "CEB8FF" ORANGE = "FFB35F" @@ -38,7 +43,7 @@ class UserGroupColor(Enum): YELLOW = "E7BF00" GRAY = "B8C4D3" - + class UserGroup(BaseModel): """ Represents a user group in Labelbox. @@ -59,14 +64,14 @@ class UserGroup(BaseModel): delete(self) -> bool get_user_groups(client: Client) -> Iterator["UserGroup"] """ + id: str name: str color: UserGroupColor users: Set[User] projects: Set[Project] client: Client - model_config = ConfigDict(arbitrary_types_allowed = True) - + model_config = ConfigDict(arbitrary_types_allowed=True) def __init__( self, @@ -75,7 +80,7 @@ def __init__( name: str = "", color: UserGroupColor = UserGroupColor.BLUE, users: Set[User] = set(), - projects: Set[Project] = set() + projects: Set[Project] = set(), ): """ Initializes a UserGroup object. @@ -91,9 +96,18 @@ def __init__( Raises: RuntimeError: If the experimental feature is not enabled in the client. 
""" - super().__init__(client=client, id=id, name=name, color=color, users=users, projects=projects) + super().__init__( + client=client, + id=id, + name=name, + color=color, + users=users, + projects=projects, + ) if not self.client.enable_experimental: - raise RuntimeError("Please enable experimental in client to use UserGroups") + raise RuntimeError( + "Please enable experimental in client to use UserGroups" + ) def get(self) -> "UserGroup": """ @@ -140,11 +154,17 @@ def get(self) -> "UserGroup": } result = self.client.execute(query, params) if not result: - raise ResourceNotFoundError(message="Failed to get user group as user group does not exist") + raise ResourceNotFoundError( + message="Failed to get user group as user group does not exist" + ) self.name = result["userGroup"]["name"] self.color = UserGroupColor(result["userGroup"]["color"]) - self.projects = self._get_projects_set(result["userGroup"]["projects"]["nodes"]) - self.users = self._get_users_set(result["userGroup"]["members"]["nodes"]) + self.projects = self._get_projects_set( + result["userGroup"]["projects"]["nodes"] + ) + self.users = self._get_users_set( + result["userGroup"]["members"]["nodes"] + ) return self def update(self) -> "UserGroup": @@ -190,23 +210,18 @@ def update(self) -> "UserGroup": } """ params = { - "id": - self.id, - "name": - self.name, - "color": - self.color.value, - "projectIds": [ - project.uid for project in self.projects - ], - "userIds": [ - user.uid for user in self.users - ] + "id": self.id, + "name": self.name, + "color": self.color.value, + "projectIds": [project.uid for project in self.projects], + "userIds": [user.uid for user in self.users], } try: result = self.client.execute(query, params) if not result: - raise ResourceNotFoundError(message="Failed to update user group as user group does not exist") + raise ResourceNotFoundError( + message="Failed to update user group as user group does not exist" + ) except MalformedQueryException as e: raise UnprocessableEntityError("Failed to update user group") from e return self @@ -257,26 +272,22 @@ def create(self) -> "UserGroup": } """ params = { - "name": - self.name, - "color": - self.color.value, - "projectIds": [ - project.uid for project in self.projects - ], - "userIds": [ - user.uid for user in self.users - ] + "name": self.name, + "color": self.color.value, + "projectIds": [project.uid for project in self.projects], + "userIds": [user.uid for user in self.users], } result = None error = None - try: + try: result = self.client.execute(query, params) except Exception as e: error = e if not result or error: # this is client side only, server doesn't have an equivalent error - raise ResourceCreationError(f"Failed to create user group, either user group name is in use currently, or provided user or projects don't exist server error: {error}") + raise ResourceCreationError( + f"Failed to create user group, either user group name is in use currently, or provided user or projects don't exist server error: {error}" + ) result = result["createUserGroup"]["group"] self.id = result["id"] return self @@ -291,7 +302,7 @@ def delete(self) -> bool: Returns: bool: True if the user group was successfully deleted, False otherwise. - + Raises: ResourceNotFoundError: If the deletion of the user group fails due to not existing ValueError: If the group ID is not provided. 
@@ -308,7 +319,9 @@ def delete(self) -> bool: params = {"id": self.id} result = self.client.execute(query, params) if not result: - raise ResourceNotFoundError(message="Failed to delete user group as user group does not exist") + raise ResourceNotFoundError( + message="Failed to delete user group as user group does not exist" + ) return result["deleteUserGroup"]["success"] def get_user_groups(self) -> Iterator["UserGroup"]: @@ -349,8 +362,9 @@ def get_user_groups(self) -> Iterator["UserGroup"]: """ nextCursor = None while True: - userGroups = self.client.execute( - query, {"after": nextCursor})["userGroups"] + userGroups = self.client.execute(query, {"after": nextCursor})[ + "userGroups" + ] if not userGroups: return yield @@ -361,7 +375,9 @@ def get_user_groups(self) -> Iterator["UserGroup"]: userGroup.name = group["name"] userGroup.color = UserGroupColor(group["color"]) userGroup.users = self._get_users_set(group["members"]["nodes"]) - userGroup.projects = self._get_projects_set(group["projects"]["nodes"]) + userGroup.projects = self._get_projects_set( + group["projects"]["nodes"] + ) yield userGroup nextCursor = userGroups["nextCursor"] if not nextCursor: diff --git a/libs/labelbox/src/labelbox/schema/webhook.py b/libs/labelbox/src/labelbox/schema/webhook.py index 1f1653c52..0eebe157e 100644 --- a/libs/labelbox/src/labelbox/schema/webhook.py +++ b/libs/labelbox/src/labelbox/schema/webhook.py @@ -10,7 +10,7 @@ class Webhook(DbObject, Updateable): - """ Represents a server-side rule for sending notifications to a web-server + """Represents a server-side rule for sending notifications to a web-server whenever one of several predefined actions happens within a context of a Project or an Organization. @@ -53,7 +53,7 @@ class Topic(Enum): @staticmethod def create(client, topics, url, secret, project) -> "Webhook": - """ Creates a Webhook. + """Creates a Webhook. Args: client (Client): The Labelbox client used to connect @@ -84,13 +84,19 @@ def create(client, topics, url, secret, project) -> "Webhook": raise ValueError("URL must be a non-empty string.") Webhook.validate_topics(topics) - project_str = "" if project is None \ - else ("project:{id:\"%s\"}," % project.uid) + project_str = ( + "" if project is None else ('project:{id:"%s"},' % project.uid) + ) query_str = """mutation CreateWebhookPyApi { createWebhook(data:{%s topics:{set:[%s]}, url:"%s", secret:"%s" }){%s} - } """ % (project_str, " ".join(topics), url, secret, - query.results_query_part(Entity.Webhook)) + } """ % ( + project_str, + " ".join(topics), + url, + secret, + query.results_query_part(Entity.Webhook), + ) return Webhook(client, client.execute(query_str)["createWebhook"]) @@ -98,7 +104,8 @@ def create(client, topics, url, secret, project) -> "Webhook": def validate_topics(topics) -> None: if isinstance(topics, str) or not isinstance(topics, Iterable): raise TypeError( - f"Topics must be List[Webhook.Topic]. Found `{topics}`") + f"Topics must be List[Webhook.Topic]. Found `{topics}`" + ) for topic in topics: Webhook.validate_value(topic, Webhook.Topic) @@ -118,7 +125,7 @@ def delete(self) -> None: self.update(status=self.Status.INACTIVE.value) def update(self, topics=None, url=None, status=None): - """ Updates the Webhook. + """Updates the Webhook. Args: topics (Optional[List[Topic]]): The new topics. 
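The webhook.py changes are likewise formatting-only. Since the surrounding hunks show Webhook.create building the createWebhook mutation, delete acting as a soft delete (it flips status to INACTIVE), and update rebuilding only the fields that are passed, a hedged usage sketch may help; the topic constant, URL, and secret below are assumptions rather than values taken from this patch.

from labelbox import Client
from labelbox.schema.webhook import Webhook

# Illustrative sketch only; topic/url/secret/project values are placeholders.
client = Client(api_key="<YOUR_API_KEY>")
project = client.get_project("<PROJECT_ID>")

webhook = Webhook.create(
    client,
    topics=[Webhook.Topic.LABEL_CREATED.value],  # assumed topic; must be a non-string iterable
    url="https://example.com/labelbox-events",
    secret="shared-secret",
    project=project,  # create() also accepts None for the project argument
)

webhook.update(url="https://example.com/labelbox-events-v2")  # only the given fields change
webhook.delete()  # soft delete: sets status to INACTIVE via updateWebhook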
@@ -137,15 +144,17 @@ def update(self, topics=None, url=None, status=None): if status is not None: self.validate_value(status, self.Status) - topics_str = "" if topics is None \ - else "topics: {set: [%s]}" % " ".join(topics) - url_str = "" if url is None else "url: \"%s\"" % url + topics_str = ( + "" if topics is None else "topics: {set: [%s]}" % " ".join(topics) + ) + url_str = "" if url is None else 'url: "%s"' % url status_str = "" if status is None else "status: %s" % status query_str = """mutation UpdateWebhookPyApi { updateWebhook(where: {id: "%s"} data:{%s}){%s}} """ % ( - self.uid, ", ".join(filter(None, - (topics_str, url_str, status_str))), - query.results_query_part(Entity.Webhook)) + self.uid, + ", ".join(filter(None, (topics_str, url_str, status_str))), + query.results_query_part(Entity.Webhook), + ) self._set_field_values(self.client.execute(query_str)["updateWebhook"]) diff --git a/libs/labelbox/src/labelbox/types.py b/libs/labelbox/src/labelbox/types.py index 98f7042ae..0c0c2904f 100644 --- a/libs/labelbox/src/labelbox/types.py +++ b/libs/labelbox/src/labelbox/types.py @@ -3,4 +3,4 @@ except ImportError: raise ImportError( "There are missing dependencies for `labelbox.types`, use `pip install labelbox[data] --upgrade` to install missing dependencies." - ) \ No newline at end of file + ) diff --git a/libs/labelbox/src/labelbox/typing_imports.py b/libs/labelbox/src/labelbox/typing_imports.py index 2c2716710..6edfb9bef 100644 --- a/libs/labelbox/src/labelbox/typing_imports.py +++ b/libs/labelbox/src/labelbox/typing_imports.py @@ -1,10 +1,11 @@ """ -This module imports types that differ across python versions, so other modules +This module imports types that differ across python versions, so other modules don't have to worry about where they should be imported from. """ import sys + if sys.version_info >= (3, 8): from typing import Literal else: - from typing_extensions import Literal \ No newline at end of file + from typing_extensions import Literal diff --git a/libs/labelbox/src/labelbox/utils.py b/libs/labelbox/src/labelbox/utils.py index 21f0c338b..c76ce188f 100644 --- a/libs/labelbox/src/labelbox/utils.py +++ b/libs/labelbox/src/labelbox/utils.py @@ -6,11 +6,17 @@ from dateutil.utils import default_tzinfo from urllib.parse import urlparse -from pydantic import BaseModel, ConfigDict, model_serializer, AliasGenerator, AliasChoices +from pydantic import ( + BaseModel, + ConfigDict, + model_serializer, + AliasGenerator, + AliasChoices, +) from pydantic.alias_generators import to_camel, to_pascal -UPPERCASE_COMPONENTS = ['uri', 'rgb'] -ISO_DATETIME_FORMAT = '%Y-%m-%dT%H:%M:%SZ' +UPPERCASE_COMPONENTS = ["uri", "rgb"] +ISO_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" DFLT_TZ = tzoffset("UTC", 0000) @@ -26,22 +32,22 @@ def _convert(s, sep, title): def camel_case(s): - """ Converts a string in [snake|camel|title]case to camelCase. """ + """Converts a string in [snake|camel|title]case to camelCase.""" return _convert(s, "", lambda i: i > 0) def title_case(s): - """ Converts a string in [snake|camel|title]case to TitleCase. """ + """Converts a string in [snake|camel|title]case to TitleCase.""" return _convert(s, "", lambda i: True) def snake_case(s): - """ Converts a string in [snake|camel|title]case to snake_case. """ + """Converts a string in [snake|camel|title]case to snake_case.""" return _convert(s, "_", lambda i: False) def sentence_case(s: str) -> str: - """ Converts a string in [snake|camel|title]case to Sentence case. 
""" + """Converts a string in [snake|camel|title]case to Sentence case.""" # Replace underscores with spaces and convert to lower case sentence_str = s.replace("_", " ").lower() # Capitalize the first letter of each word @@ -62,7 +68,11 @@ def is_valid_uri(uri): class _CamelCaseMixin(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed = True, alias_generator = to_camel, populate_by_name = True) + model_config = ConfigDict( + arbitrary_types_allowed=True, + alias_generator=to_camel, + populate_by_name=True, + ) class _NoCoercionMixin: @@ -72,7 +82,7 @@ class _NoCoercionMixin: uninteded behavior. This mixin uses a class_name discriminator field to prevent pydantic from - corecing the type of the object. Add a class_name field to the class you + corecing the type of the object. Add a class_name field to the class you want to discrimniate and use this mixin class to remove the discriminator when serializing the object. @@ -81,10 +91,11 @@ class ConversationData(BaseData, _NoCoercionMixin): class_name: Literal["ConversationData"] = "ConversationData" """ + @model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) - res.pop('class_name') + res.pop("class_name") return res diff --git a/libs/labelbox/tests/conftest.py b/libs/labelbox/tests/conftest.py index 4251ac698..446db396b 100644 --- a/libs/labelbox/tests/conftest.py +++ b/libs/labelbox/tests/conftest.py @@ -53,29 +53,30 @@ @pytest.fixture(scope="session") def rand_gen(): - def gen(field_type): if field_type is str: - return "".join(ascii_letters[randint(0, - len(ascii_letters) - 1)] - for _ in range(16)) + return "".join( + ascii_letters[randint(0, len(ascii_letters) - 1)] + for _ in range(16) + ) if field_type is datetime: return datetime.now() - raise Exception("Can't random generate for field type '%r'" % - field_type) + raise Exception( + "Can't random generate for field type '%r'" % field_type + ) return gen class Environ(Enum): - LOCAL = 'local' - PROD = 'prod' - STAGING = 'staging' - CUSTOM = 'custom' - STAGING_EU = 'staging-eu' - EPHEMERAL = 'ephemeral' # Used for testing PRs with ephemeral environments + LOCAL = "local" + PROD = "prod" + STAGING = "staging" + CUSTOM = "custom" + STAGING_EU = "staging-eu" + EPHEMERAL = "ephemeral" # Used for testing PRs with ephemeral environments @pytest.fixture @@ -89,48 +90,50 @@ def external_id() -> str: def ephemeral_endpoint() -> str: - return os.getenv('LABELBOX_TEST_BASE_URL', EPHEMERAL_BASE_URL) + return os.getenv("LABELBOX_TEST_BASE_URL", EPHEMERAL_BASE_URL) def graphql_url(environ: str) -> str: if environ == Environ.LOCAL: - return 'http://localhost:3000/api/graphql' + return "http://localhost:3000/api/graphql" elif environ == Environ.PROD: - return 'https://api.labelbox.com/graphql' + return "https://api.labelbox.com/graphql" elif environ == Environ.STAGING: - return 'https://api.lb-stage.xyz/graphql' + return "https://api.lb-stage.xyz/graphql" elif environ == Environ.CUSTOM: graphql_api_endpoint = os.environ.get( - 'LABELBOX_TEST_GRAPHQL_API_ENDPOINT') + "LABELBOX_TEST_GRAPHQL_API_ENDPOINT" + ) if graphql_api_endpoint is None: raise Exception("Missing LABELBOX_TEST_GRAPHQL_API_ENDPOINT") return graphql_api_endpoint elif environ == Environ.EPHEMERAL: return f"{ephemeral_endpoint()}/graphql" - return 'http://host.docker.internal:8080/graphql' + return "http://host.docker.internal:8080/graphql" def rest_url(environ: str) -> str: if environ == Environ.LOCAL: - return 'http://localhost:3000/api/v1' + return "http://localhost:3000/api/v1" elif environ 
== Environ.PROD: - return 'https://api.labelbox.com/api/v1' + return "https://api.labelbox.com/api/v1" elif environ == Environ.STAGING: - return 'https://api.lb-stage.xyz/api/v1' + return "https://api.lb-stage.xyz/api/v1" elif environ == Environ.CUSTOM: - rest_api_endpoint = os.environ.get('LABELBOX_TEST_REST_API_ENDPOINT') + rest_api_endpoint = os.environ.get("LABELBOX_TEST_REST_API_ENDPOINT") if rest_api_endpoint is None: raise Exception("Missing LABELBOX_TEST_REST_API_ENDPOINT") return rest_api_endpoint elif environ == Environ.EPHEMERAL: return f"{ephemeral_endpoint()}/api/v1" - return 'http://host.docker.internal:8080/api/v1' + return "http://host.docker.internal:8080/api/v1" def testing_api_key(environ: Environ) -> str: keys = [ f"LABELBOX_TEST_API_KEY_{environ.value.upper()}", - "LABELBOX_TEST_API_KEY", "LABELBOX_API_KEY" + "LABELBOX_TEST_API_KEY", + "LABELBOX_API_KEY", ] for key in keys: value = os.environ.get(key) @@ -143,47 +146,51 @@ def service_api_key() -> str: service_api_key = os.environ["SERVICE_API_KEY"] if service_api_key is None: raise Exception( - "SERVICE_API_KEY is missing and needed for admin client") + "SERVICE_API_KEY is missing and needed for admin client" + ) return service_api_key class IntegrationClient(Client): - def __init__(self, environ: str) -> None: api_url = graphql_url(environ) api_key = testing_api_key(environ) rest_endpoint = rest_url(environ) - super().__init__(api_key, - api_url, - enable_experimental=True, - rest_endpoint=rest_endpoint) + super().__init__( + api_key, + api_url, + enable_experimental=True, + rest_endpoint=rest_endpoint, + ) self.queries = [] def execute(self, query=None, params=None, check_naming=True, **kwargs): if check_naming and query is not None: - assert re.match(r"\s*(?:query|mutation) \w+PyApi", - query) is not None + assert ( + re.match(r"\s*(?:query|mutation) \w+PyApi", query) is not None + ) self.queries.append((query, params)) - if not kwargs.get('timeout'): - kwargs['timeout'] = 30.0 + if not kwargs.get("timeout"): + kwargs["timeout"] = 30.0 return super().execute(query, params, **kwargs) class AdminClient(Client): - def __init__(self, env): """ - The admin client creates organizations and users using admin api described here https://labelbox.atlassian.net/wiki/spaces/AP/pages/2206564433/Internal+Admin+APIs. + The admin client creates organizations and users using admin api described here https://labelbox.atlassian.net/wiki/spaces/AP/pages/2206564433/Internal+Admin+APIs. 
""" self._api_key = service_api_key() self._admin_endpoint = f"{ephemeral_endpoint()}/admin/v1" self._api_url = graphql_url(env) self._rest_endpoint = rest_url(env) - super().__init__(self._api_key, - self._api_url, - enable_experimental=True, - rest_endpoint=self._rest_endpoint) + super().__init__( + self._api_key, + self._api_url, + enable_experimental=True, + rest_endpoint=self._rest_endpoint, + ) def _create_organization(self) -> str: endpoint = f"{self._admin_endpoint}/organizations/" @@ -195,12 +202,14 @@ def _create_organization(self) -> str: data = response.json() if response.status_code not in [ - requests.codes.created, requests.codes.ok + requests.codes.created, + requests.codes.ok, ]: - raise Exception("Failed to create org, message: " + - str(data['message'])) + raise Exception( + "Failed to create org, message: " + str(data["message"]) + ) - return data['id'] + return data["id"] def _create_user(self, organization_id=None) -> Tuple[str, str]: if organization_id is None: @@ -221,31 +230,35 @@ def _create_user(self, organization_id=None) -> Tuple[str, str]: ) data = response.json() if response.status_code not in [ - requests.codes.created, requests.codes.ok + requests.codes.created, + requests.codes.ok, ]: - raise Exception("Failed to create user, message: " + - str(data['message'])) + raise Exception( + "Failed to create user, message: " + str(data["message"]) + ) - user_identity_id = data['identityId'] + user_identity_id = data["identityId"] - endpoint = f"{self._admin_endpoint}/organizations/{organization_id}/users/" + endpoint = ( + f"{self._admin_endpoint}/organizations/{organization_id}/users/" + ) response = requests.post( endpoint, headers=self.headers, - json={ - "identityId": user_identity_id, - "organizationRole": "Admin" - }, + json={"identityId": user_identity_id, "organizationRole": "Admin"}, ) data = response.json() if response.status_code not in [ - requests.codes.created, requests.codes.ok + requests.codes.created, + requests.codes.ok, ]: - raise Exception("Failed to create link user to org, message: " + - str(data['message'])) + raise Exception( + "Failed to create link user to org, message: " + + str(data["message"]) + ) - user_id = data['id'] + user_id = data["id"] endpoint = f"{self._admin_endpoint}/users/{user_id}/token" response = requests.get( @@ -254,10 +267,13 @@ def _create_user(self, organization_id=None) -> Tuple[str, str]: ) data = response.json() if response.status_code not in [ - requests.codes.created, requests.codes.ok + requests.codes.created, + requests.codes.ok, ]: - raise Exception("Failed to create ephemeral user, message: " + - str(data['message'])) + raise Exception( + "Failed to create ephemeral user, message: " + + str(data["message"]) + ) token = data["token"] @@ -282,17 +298,18 @@ def create_api_key_for_user(self) -> str: class EphemeralClient(Client): - def __init__(self, environ=Environ.EPHEMERAL): self.admin_client = AdminClient(environ) self.api_key = self.admin_client.create_api_key_for_user() api_url = graphql_url(environ) rest_endpoint = rest_url(environ) - super().__init__(self.api_key, - api_url, - enable_experimental=True, - rest_endpoint=rest_endpoint) + super().__init__( + self.api_key, + api_url, + enable_experimental=True, + rest_endpoint=rest_endpoint, + ) @pytest.fixture @@ -322,7 +339,7 @@ def environ() -> Environ: value = os.environ.get(key) if value is not None: return Environ(value) - raise Exception(f'Missing env key in: {os.environ}') + raise Exception(f"Missing env key in: {os.environ}") def 
cancel_invite(client, invite_id): @@ -331,7 +348,7 @@ def cancel_invite(client, invite_id): """ query_str = """mutation CancelInvitePyApi($where: WhereUniqueIdInput!) { cancelInvite(where: $where) {id}}""" - client.execute(query_str, {'where': {'id': invite_id}}, experimental=True) + client.execute(query_str, {"where": {"id": invite_id}}, experimental=True) def get_project_invites(client, project_id): @@ -344,11 +361,14 @@ def get_project_invites(client, project_id): invites(from: $from, first: $first) { nodes { %s projectInvites { projectId projectRoleName } } nextCursor}}} """ % (id_param, id_param, query.results_query_part(Invite)) - return PaginatedCollection(client, - query_str, {id_param: project_id}, - ['project', 'invites', 'nodes'], - Invite, - cursor_path=['project', 'invites', 'nextCursor']) + return PaginatedCollection( + client, + query_str, + {id_param: project_id}, + ["project", "invites", "nodes"], + Invite, + cursor_path=["project", "invites", "nextCursor"], + ) def get_invites(client): @@ -360,18 +380,23 @@ def get_invites(client): nodes { id createdAt organizationRoleName inviteeEmail } nextCursor }}}""" invites = PaginatedCollection( client, - query_str, {}, ['organization', 'invites', 'nodes'], + query_str, + {}, + ["organization", "invites", "nodes"], Invite, - cursor_path=['organization', 'invites', 'nextCursor'], - experimental=True) + cursor_path=["organization", "invites", "nextCursor"], + experimental=True, + ) return invites @pytest.fixture def queries(): - return SimpleNamespace(cancel_invite=cancel_invite, - get_project_invites=get_project_invites, - get_invites=get_invites) + return SimpleNamespace( + cancel_invite=cancel_invite, + get_project_invites=get_project_invites, + get_invites=get_invites, + ) @pytest.fixture(scope="session") @@ -388,52 +413,57 @@ def client(environ: str): @pytest.fixture(scope="session") def pdf_url(client): - pdf_url = client.upload_file('tests/assets/loremipsum.pdf') - return {"row_data": {"pdf_url": pdf_url,}, "global_key": str(uuid.uuid4())} + pdf_url = client.upload_file("tests/assets/loremipsum.pdf") + return { + "row_data": { + "pdf_url": pdf_url, + }, + "global_key": str(uuid.uuid4()), + } @pytest.fixture(scope="session") def pdf_entity_data_row(client): pdf_url = client.upload_file( - 'tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483.pdf') + "tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483.pdf" + ) text_layer_url = client.upload_file( - 'tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483-lb-textlayer.json' + "tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483-lb-textlayer.json" ) return { - "row_data": { - "pdf_url": pdf_url, - "text_layer_url": text_layer_url - }, - "global_key": str(uuid.uuid4()) + "row_data": {"pdf_url": pdf_url, "text_layer_url": text_layer_url}, + "global_key": str(uuid.uuid4()), } @pytest.fixture() def conversation_entity_data_row(client, rand_gen): return { - "row_data": - "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", - "global_key": - f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{rand_gen(str)}", + "row_data": "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", + "global_key": f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{rand_gen(str)}", } @pytest.fixture 
def project(client, rand_gen): - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) + project = client.create_project( + name=rand_gen(str), + queue_mode=QueueMode.Batch, + media_type=MediaType.Image, + ) yield project project.delete() @pytest.fixture def consensus_project(client, rand_gen): - project = client.create_project(name=rand_gen(str), - quality_mode=QualityMode.Consensus, - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) + project = client.create_project( + name=rand_gen(str), + quality_mode=QualityMode.Consensus, + queue_mode=QueueMode.Batch, + media_type=MediaType.Image, + ) yield project project.delete() @@ -443,23 +473,24 @@ def model_config(client, rand_gen, valid_model_id): model_config = client.create_model_config( name=rand_gen(str), model_id=valid_model_id, - inference_params={"param": "value"}) + inference_params={"param": "value"}, + ) yield model_config client.delete_model_config(model_config.uid) @pytest.fixture -def consensus_project_with_batch(consensus_project, initial_dataset, rand_gen, - image_url): +def consensus_project_with_batch( + consensus_project, initial_dataset, rand_gen, image_url +): project = consensus_project dataset = initial_dataset data_rows = [] for _ in range(3): - data_rows.append({ - DataRow.row_data: image_url, - DataRow.global_key: str(uuid.uuid4()) - }) + data_rows.append( + {DataRow.row_data: image_url, DataRow.global_key: str(uuid.uuid4())} + ) task = dataset.create_data_rows(data_rows) task.wait_till_done() assert task.status == "COMPLETE" @@ -469,7 +500,7 @@ def consensus_project_with_batch(consensus_project, initial_dataset, rand_gen, batch = project.create_batch( rand_gen(str), data_rows, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) yield [project, batch, data_rows] @@ -483,7 +514,7 @@ def dataset(client, rand_gen): dataset.delete() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def unique_dataset(client, rand_gen): dataset = client.create_dataset(name=rand_gen(str)) yield dataset @@ -492,12 +523,12 @@ def unique_dataset(client, rand_gen): @pytest.fixture def small_dataset(dataset: Dataset): - task = dataset.create_data_rows([ - { - "row_data": SMALL_DATASET_URL, - "external_id": "my-image" - }, - ] * 2) + task = dataset.create_data_rows( + [ + {"row_data": SMALL_DATASET_URL, "external_id": "my-image"}, + ] + * 2 + ) task.wait_till_done() yield dataset @@ -506,13 +537,15 @@ def small_dataset(dataset: Dataset): @pytest.fixture def data_row(dataset, image_url, rand_gen): global_key = f"global-key-{rand_gen(str)}" - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image", - "global_key": global_key - }, - ]) + task = dataset.create_data_rows( + [ + { + "row_data": image_url, + "external_id": "my-image", + "global_key": global_key, + }, + ] + ) task.wait_till_done() dr = dataset.data_rows().get_one() yield dr @@ -522,13 +555,15 @@ def data_row(dataset, image_url, rand_gen): @pytest.fixture def data_row_and_global_key(dataset, image_url, rand_gen): global_key = f"global-key-{rand_gen(str)}" - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image", - "global_key": global_key - }, - ]) + task = dataset.create_data_rows( + [ + { + "row_data": image_url, + "external_id": "my-image", + "global_key": global_key, + }, + ] + ) task.wait_till_done() dr = dataset.data_rows().get_one() yield dr, global_key @@ 
-539,10 +574,11 @@ def data_row_and_global_key(dataset, image_url, rand_gen): # @pytest.mark.parametrize('data_rows', [], indirect=True) # if omitted, count defaults to 1 @pytest.fixture -def data_rows(dataset, image_url, request, wait_for_data_row_processing, - client): +def data_rows( + dataset, image_url, request, wait_for_data_row_processing, client +): count = 1 - if hasattr(request, 'param'): + if hasattr(request, "param"): count = request.param datarows = [ @@ -565,26 +601,26 @@ def data_rows(dataset, image_url, request, wait_for_data_row_processing, @pytest.fixture def iframe_url(environ) -> str: if environ in [Environ.PROD, Environ.LOCAL]: - return 'https://editor.labelbox.com' + return "https://editor.labelbox.com" elif environ == Environ.STAGING: - return 'https://editor.lb-stage.xyz' + return "https://editor.lb-stage.xyz" @pytest.fixture def sample_image() -> str: - path_to_video = 'tests/integration/media/sample_image.jpg' + path_to_video = "tests/integration/media/sample_image.jpg" return path_to_video @pytest.fixture def sample_video() -> str: - path_to_video = 'tests/integration/media/cat.mp4' + path_to_video = "tests/integration/media/cat.mp4" return path_to_video @pytest.fixture def sample_bulk_conversation() -> list: - path_to_conversation = 'tests/integration/media/bulk_conversation.json' + path_to_conversation = "tests/integration/media/bulk_conversation.json" with open(path_to_conversation) as json_file: conversations = json.load(json_file) return conversations @@ -599,8 +635,15 @@ def organization(client): @pytest.fixture -def configured_project_with_label(client, rand_gen, image_url, project, dataset, - data_row, wait_for_label_processing): +def configured_project_with_label( + client, + rand_gen, + image_url, + project, + dataset, + data_row, + wait_for_label_processing, +): """Project with a connected dataset, having one datarow Project contains an ontology with 1 bbox tool Additionally includes a create_label method for any needed extra labels @@ -609,16 +652,18 @@ def configured_project_with_label(client, rand_gen, image_url, project, dataset, project._wait_until_data_rows_are_processed( data_row_ids=[data_row.uid], wait_processing_max_seconds=DATA_ROW_PROCESSING_WAIT_TIMEOUT_SECONDS, - sleep_interval=DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS) + sleep_interval=DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS, + ) project.create_batch( rand_gen(str), [data_row.uid], # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) ontology = _setup_ontology(project) - label = _create_label(project, data_row, ontology, - wait_for_label_processing) + label = _create_label( + project, data_row, ontology, wait_for_label_processing + ) yield [project, dataset, data_row, label] for label in project.labels(): @@ -626,32 +671,32 @@ def configured_project_with_label(client, rand_gen, image_url, project, dataset, def _create_label(project, data_row, ontology, wait_for_label_processing): - predictions = [{ - "uuid": str(uuid.uuid4()), - "schemaId": ontology.tools[0].feature_schema_id, - "dataRow": { - "id": data_row.uid - }, - "bbox": { - "top": 20, - "left": 20, - "height": 50, - "width": 50 + predictions = [ + { + "uuid": str(uuid.uuid4()), + "schemaId": ontology.tools[0].feature_schema_id, + "dataRow": {"id": data_row.uid}, + "bbox": {"top": 20, "left": 20, "height": 50, "width": 50}, } - }] + ] def create_label(): - """ Ad-hoc function to create a LabelImport + """Ad-hoc function to create a LabelImport 
Creates a LabelImport task which will create a label """ upload_task = LabelImport.create_from_objects( - project.client, project.uid, f'label-import-{uuid.uuid4()}', - predictions) + project.client, + project.uid, + f"label-import-{uuid.uuid4()}", + predictions, + ) upload_task.wait_until_done(sleep_time_seconds=5) - assert upload_task.state == AnnotationImportState.FINISHED, "Label Import did not finish" - assert len( - upload_task.errors - ) == 0, f"Label Import {upload_task.name} failed with errors {upload_task.errors}" + assert ( + upload_task.state == AnnotationImportState.FINISHED + ), "Label Import did not finish" + assert ( + len(upload_task.errors) == 0 + ), f"Label Import {upload_task.name} failed with errors {upload_task.errors}" project.create_label = create_label project.create_label() @@ -662,10 +707,14 @@ def create_label(): def _setup_ontology(project): editor = list( project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - ontology_builder = OntologyBuilder(tools=[ - Tool(tool=Tool.Type.BBOX, name="test-bbox-class"), - ]) + where=LabelingFrontend.name == "editor" + ) + )[0] + ontology_builder = OntologyBuilder( + tools=[ + Tool(tool=Tool.Type.BBOX, name="test-bbox-class"), + ] + ) project.setup(editor, ontology_builder.asdict()) # TODO: ontology may not be synchronous after setup. remove sleep when api is more consistent time.sleep(2) @@ -674,34 +723,37 @@ def _setup_ontology(project): @pytest.fixture def big_dataset(dataset: Dataset): - task = dataset.create_data_rows([ - { - "row_data": IMAGE_URL, - "external_id": EXTERNAL_ID - }, - ] * 3) + task = dataset.create_data_rows( + [ + {"row_data": IMAGE_URL, "external_id": EXTERNAL_ID}, + ] + * 3 + ) task.wait_till_done() yield dataset @pytest.fixture -def configured_batch_project_with_label(project, dataset, data_row, - wait_for_label_processing): +def configured_batch_project_with_label( + project, dataset, data_row, wait_for_label_processing +): """Project with a batch having one datarow Project contains an ontology with 1 bbox tool Additionally includes a create_label method for any needed extra labels One label is already created and yielded when using fixture """ data_rows = [dr.uid for dr in list(dataset.data_rows())] - project._wait_until_data_rows_are_processed(data_row_ids=data_rows, - sleep_interval=3) + project._wait_until_data_rows_are_processed( + data_row_ids=data_rows, sleep_interval=3 + ) project.create_batch("test-batch", data_rows) project.data_row_ids = data_rows ontology = _setup_ontology(project) - label = _create_label(project, data_row, ontology, - wait_for_label_processing) + label = _create_label( + project, data_row, ontology, wait_for_label_processing + ) yield [project, dataset, data_row, label] @@ -710,15 +762,16 @@ def configured_batch_project_with_label(project, dataset, data_row, @pytest.fixture -def configured_batch_project_with_multiple_datarows(project, dataset, data_rows, - wait_for_label_processing): +def configured_batch_project_with_multiple_datarows( + project, dataset, data_rows, wait_for_label_processing +): """Project with a batch having multiple datarows Project contains an ontology with 1 bbox tool Additionally includes a create_label method for any needed extra labels """ global_keys = [dr.global_key for dr in data_rows] - batch_name = f'batch {uuid.uuid4()}' + batch_name = f"batch {uuid.uuid4()}" project.create_batch(batch_name, global_keys=global_keys) ontology = _setup_ontology(project) @@ -732,15 +785,16 @@ def 
configured_batch_project_with_multiple_datarows(project, dataset, data_rows, @pytest.fixture -def configured_batch_project_for_labeling_service(project, - data_row_and_global_key): +def configured_batch_project_for_labeling_service( + project, data_row_and_global_key +): """Project with a batch having multiple datarows Project contains an ontology with 1 bbox tool Additionally includes a create_label method for any needed extra labels """ global_keys = [data_row_and_global_key[1]] - batch_name = f'batch {uuid.uuid4()}' + batch_name = f"batch {uuid.uuid4()}" project.create_batch(batch_name, global_keys=global_keys) _setup_ontology(project) @@ -830,12 +884,9 @@ def video_data(client, rand_gen, video_data_row, wait_for_data_row_processing): def create_video_data_row(rand_gen): return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{rand_gen(str)}", - "media_type": - "VIDEO", + "row_data": "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", + "global_key": f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{rand_gen(str)}", + "media_type": "VIDEO", } @@ -857,25 +908,25 @@ def video_data_row(rand_gen): class ExportV2Helpers: - @classmethod - def run_project_export_v2_task(cls, - project, - num_retries=5, - task_name=None, - filters={}, - params={}): + def run_project_export_v2_task( + cls, project, num_retries=5, task_name=None, filters={}, params={} + ): task = None - params = params if params else { - "project_details": True, - "performance_details": False, - "data_row_details": True, - "label_details": True - } - while (num_retries > 0): - task = project.export_v2(task_name=task_name, - filters=filters, - params=params) + params = ( + params + if params + else { + "project_details": True, + "performance_details": False, + "data_row_details": True, + "label_details": True, + } + ) + while num_retries > 0: + task = project.export_v2( + task_name=task_name, filters=filters, params=params + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None @@ -887,21 +938,19 @@ def run_project_export_v2_task(cls, return task.result @classmethod - def run_dataset_export_v2_task(cls, - dataset, - num_retries=5, - task_name=None, - filters={}, - params={}): + def run_dataset_export_v2_task( + cls, dataset, num_retries=5, task_name=None, filters={}, params={} + ): task = None - params = params if params else { - "performance_details": False, - "label_details": True - } - while (num_retries > 0): - task = dataset.export_v2(task_name=task_name, - filters=filters, - params=params) + params = ( + params + if params + else {"performance_details": False, "label_details": True} + ) + while num_retries > 0: + task = dataset.export_v2( + task_name=task_name, filters=filters, params=params + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None @@ -914,23 +963,20 @@ def run_dataset_export_v2_task(cls, return task.result @classmethod - def run_catalog_export_v2_task(cls, - client, - num_retries=5, - task_name=None, - filters={}, - params={}): + def run_catalog_export_v2_task( + cls, client, num_retries=5, task_name=None, filters={}, params={} + ): task = None - params = params if params else { - "performance_details": False, - "label_details": True - } + params = ( + params + if params + else {"performance_details": False, 
"label_details": True} + ) catalog = client.get_catalog() - while (num_retries > 0): - - task = catalog.export_v2(task_name=task_name, - filters=filters, - params=params) + while num_retries > 0: + task = catalog.export_v2( + task_name=task_name, filters=filters, params=params + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None @@ -956,9 +1002,10 @@ def big_dataset_data_row_ids(big_dataset: Dataset): yield [dr.json["data_row"]["id"] for dr in stream] -@pytest.fixture(scope='function') -def dataset_with_invalid_data_rows(unique_dataset: Dataset, - upload_invalid_data_rows_for_dataset): +@pytest.fixture(scope="function") +def dataset_with_invalid_data_rows( + unique_dataset: Dataset, upload_invalid_data_rows_for_dataset +): upload_invalid_data_rows_for_dataset(unique_dataset) yield unique_dataset @@ -966,22 +1013,25 @@ def dataset_with_invalid_data_rows(unique_dataset: Dataset, @pytest.fixture def upload_invalid_data_rows_for_dataset(): - def _upload_invalid_data_rows_for_dataset(dataset: Dataset): - task = dataset.create_data_rows([ - { - "row_data": 'gs://invalid-bucket/example.png', # forbidden - "external_id": "image-without-access.jpg" - }, - ] * 2) + task = dataset.create_data_rows( + [ + { + "row_data": "gs://invalid-bucket/example.png", # forbidden + "external_id": "image-without-access.jpg", + }, + ] + * 2 + ) task.wait_till_done() return _upload_invalid_data_rows_for_dataset @pytest.fixture -def configured_project(project_with_empty_ontology, initial_dataset, rand_gen, - image_url): +def configured_project( + project_with_empty_ontology, initial_dataset, rand_gen, image_url +): dataset = initial_dataset data_row_id = dataset.create_data_row(row_data=image_url).uid project = project_with_empty_ontology @@ -989,7 +1039,7 @@ def configured_project(project_with_empty_ontology, initial_dataset, rand_gen, batch = project.create_batch( rand_gen(str), [data_row_id], # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) project.data_row_ids = [data_row_id] @@ -1002,18 +1052,23 @@ def configured_project(project_with_empty_ontology, initial_dataset, rand_gen, def project_with_empty_ontology(project): editor = list( project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] + where=LabelingFrontend.name == "editor" + ) + )[0] empty_ontology = {"tools": [], "classifications": []} project.setup(editor, empty_ontology) yield project @pytest.fixture -def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, - image_url): - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) +def configured_project_with_complex_ontology( + client, initial_dataset, rand_gen, image_url +): + project = client.create_project( + name=rand_gen(str), + queue_mode=QueueMode.Batch, + media_type=MediaType.Image, + ) dataset = initial_dataset data_row = dataset.create_data_row(row_data=image_url) data_row_ids = [data_row.uid] @@ -1021,13 +1076,15 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, project.create_batch( rand_gen(str), data_row_ids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) project.data_row_ids = data_row_ids editor = list( project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] + where=LabelingFrontend.name == "editor" + ) + )[0] ontology = OntologyBuilder() 
tools = [ @@ -1035,24 +1092,29 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, Tool(tool=Tool.Type.LINE, name="test-line-class"), Tool(tool=Tool.Type.POINT, name="test-point-class"), Tool(tool=Tool.Type.POLYGON, name="test-polygon-class"), - Tool(tool=Tool.Type.NER, name="test-ner-class") + Tool(tool=Tool.Type.NER, name="test-ner-class"), ] options = [ Option(value="first option answer"), Option(value="second option answer"), - Option(value="third option answer") + Option(value="third option answer"), ] classifications = [ - Classification(class_type=Classification.Type.TEXT, - name="test-text-class"), - Classification(class_type=Classification.Type.RADIO, - name="test-radio-class", - options=options), - Classification(class_type=Classification.Type.CHECKLIST, - name="test-checklist-class", - options=options) + Classification( + class_type=Classification.Type.TEXT, name="test-text-class" + ), + Classification( + class_type=Classification.Type.RADIO, + name="test-radio-class", + options=options, + ), + Classification( + class_type=Classification.Type.CHECKLIST, + name="test-checklist-class", + options=options, + ), ] for t in tools: @@ -1070,7 +1132,6 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, @pytest.fixture def embedding(client: Client, environ): - uuid_str = uuid.uuid4().hex time.sleep(randint(1, 5)) embedding = client.create_embedding(f"sdk-int-{uuid_str}", 8) @@ -1085,13 +1146,16 @@ def valid_model_id(): @pytest.fixture -def requested_labeling_service(rand_gen, - live_chat_evaluation_project_with_new_dataset, - chat_evaluation_ontology, model_config): +def requested_labeling_service( + rand_gen, + live_chat_evaluation_project_with_new_dataset, + chat_evaluation_ontology, + model_config, +): project = live_chat_evaluation_project_with_new_dataset project.connect_ontology(chat_evaluation_ontology) - project.upsert_instructions('tests/integration/media/sample_pdf.pdf') + project.upsert_instructions("tests/integration/media/sample_pdf.pdf") labeling_service = project.get_labeling_service() project.add_model_config(model_config.uid) diff --git a/libs/labelbox/tests/data/annotation_import/conftest.py b/libs/labelbox/tests/data/annotation_import/conftest.py index 370af0517..39cede0bb 100644 --- a/libs/labelbox/tests/data/annotation_import/conftest.py +++ b/libs/labelbox/tests/data/annotation_import/conftest.py @@ -15,6 +15,7 @@ from labelbox.schema.annotation_import import LabelImport, AnnotationImportState from pytest import FixtureRequest from contextlib import suppress + """ The main fixtures of this library are configured_project and configured_project_by_global_key. Both fixtures generate data rows with a parametrize media type. They create the amount of data rows equal to the DATA_ROW_COUNT variable below. The data rows are generated with a factory fixture that returns a function that allows you to pass a global key. The ontologies are generated normalized and based on the MediaType given (i.e. only features supported by MediaType are created). This ontology is later used to obtain the correct annotations with the prediction_id_mapping and corresponding inferences. Each data row will have all possible annotations attached supported for the MediaType. 
""" @@ -26,15 +27,11 @@ @pytest.fixture(scope="module", autouse=True) def video_data_row_factory(): - def video_data_row(global_key): return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{global_key}", - "media_type": - "VIDEO", + "row_data": "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", + "global_key": f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{global_key}", + "media_type": "VIDEO", } return video_data_row @@ -42,15 +39,11 @@ def video_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def audio_data_row_factory(): - def audio_data_row(global_key): return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/audio-sample-data/sample-audio-1.mp3", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/audio-sample-data/sample-audio-1.mp3-{global_key}", - "media_type": - "AUDIO", + "row_data": "https://storage.googleapis.com/labelbox-datasets/audio-sample-data/sample-audio-1.mp3", + "global_key": f"https://storage.googleapis.com/labelbox-datasets/audio-sample-data/sample-audio-1.mp3-{global_key}", + "media_type": "AUDIO", } return audio_data_row @@ -58,13 +51,10 @@ def audio_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def conversational_data_row_factory(): - def conversational_data_row(global_key): return { - "row_data": - "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", - "global_key": - f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{global_key}", + "row_data": "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", + "global_key": f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{global_key}", } return conversational_data_row @@ -72,15 +62,11 @@ def conversational_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def dicom_data_row_factory(): - def dicom_data_row(global_key): return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/dicom-sample-data/sample-dicom-1.dcm", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/dicom-sample-data/sample-dicom-1.dcm-{global_key}", - "media_type": - "DICOM", + "row_data": "https://storage.googleapis.com/labelbox-datasets/dicom-sample-data/sample-dicom-1.dcm", + "global_key": f"https://storage.googleapis.com/labelbox-datasets/dicom-sample-data/sample-dicom-1.dcm-{global_key}", + "media_type": "DICOM", } return dicom_data_row @@ -88,27 +74,20 @@ def dicom_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def geospatial_data_row_factory(): - def geospatial_data_row(global_key): return { "row_data": { - "tile_layer_url": - "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png", + "tile_layer_url": "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png", "bounds": [ [19.405662413477728, -99.21052827588443], [19.400498983095076, -99.20534818927473], ], - "min_zoom": - 12, - "max_zoom": - 20, - "epsg": - "EPSG4326", + "min_zoom": 12, + "max_zoom": 20, + "epsg": "EPSG4326", }, - "global_key": - 
f"https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/z/x/y.png-{global_key}", - "media_type": - "TMS_GEO", + "global_key": f"https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/z/x/y.png-{global_key}", + "media_type": "TMS_GEO", } return geospatial_data_row @@ -116,13 +95,10 @@ def geospatial_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def html_data_row_factory(): - def html_data_row(global_key): return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html-{global_key}", + "row_data": "https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html", + "global_key": f"https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html-{global_key}", } return html_data_row @@ -130,15 +106,11 @@ def html_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def image_data_row_factory(): - def image_data_row(global_key): return { - "row_data": - "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg", - "global_key": - f"https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-{global_key}", - "media_type": - "IMAGE", + "row_data": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg", + "global_key": f"https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-{global_key}", + "media_type": "IMAGE", } return image_data_row @@ -146,19 +118,14 @@ def image_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def document_data_row_factory(): - def document_data_row(global_key): return { "row_data": { - "pdf_url": - "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf", - "text_layer_url": - "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483-lb-textlayer.json", + "pdf_url": "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf", + "text_layer_url": "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483-lb-textlayer.json", }, - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf-{global_key}", - "media_type": - "PDF", + "global_key": f"https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf-{global_key}", + "media_type": "PDF", } return document_data_row @@ -166,15 +133,11 @@ def document_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def text_data_row_factory(): - def text_data_row(global_key): return { - "row_data": - "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-2.txt", - "global_key": - f"https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-2.txt-{global_key}", - "media_type": - "TEXT", + "row_data": "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-2.txt", + "global_key": f"https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-2.txt-{global_key}", + "media_type": "TEXT", } return text_data_row @@ -182,13 +145,10 @@ def text_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def llm_human_preference_data_row_factory(): - def 
llm_human_preference_data_row(global_key): return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/sdk_test/llm_prompt_response_conv.json", - "global_key": - global_key, + "row_data": "https://storage.googleapis.com/labelbox-datasets/sdk_test/llm_prompt_response_conv.json", + "global_key": global_key, } return llm_human_preference_data_row @@ -224,60 +184,50 @@ def normalized_ontology_by_media_type(): """Returns NDJSON of ontology based on media type""" bbox_tool_with_nested_text = { - "required": - False, - "name": - "bbox_tool_with_nested_text", - "tool": - "rectangle", - "color": - "#a23030", - "classifications": [{ - "required": - False, - "instructions": - "nested", - "name": - "nested", - "type": - "radio", - "options": [{ - "label": - "radio_value_1", - "value": - "radio_value_1", + "required": False, + "name": "bbox_tool_with_nested_text", + "tool": "rectangle", + "color": "#a23030", + "classifications": [ + { + "required": False, + "instructions": "nested", + "name": "nested", + "type": "radio", "options": [ { - "required": - False, - "instructions": - "nested_checkbox", - "name": - "nested_checkbox", - "type": - "checklist", + "label": "radio_value_1", + "value": "radio_value_1", "options": [ { - "label": "nested_checkbox_option_1", - "value": "nested_checkbox_option_1", - "options": [], + "required": False, + "instructions": "nested_checkbox", + "name": "nested_checkbox", + "type": "checklist", + "options": [ + { + "label": "nested_checkbox_option_1", + "value": "nested_checkbox_option_1", + "options": [], + }, + { + "label": "nested_checkbox_option_2", + "value": "nested_checkbox_option_2", + }, + ], }, { - "label": "nested_checkbox_option_2", - "value": "nested_checkbox_option_2", + "required": False, + "instructions": "nested_text", + "name": "nested_text", + "type": "text", + "options": [], }, ], }, - { - "required": False, - "instructions": "nested_text", - "name": "nested_text", - "type": "text", - "options": [], - }, ], - },], - }], + } + ], } bbox_tool = { @@ -331,44 +281,35 @@ def normalized_ontology_by_media_type(): "classifications": [], } checklist = { - "required": - False, - "instructions": - "checklist", - "name": - "checklist", - "type": - "checklist", + "required": False, + "instructions": "checklist", + "name": "checklist", + "type": "checklist", "options": [ { "label": "first_checklist_answer", - "value": "first_checklist_answer" + "value": "first_checklist_answer", }, { "label": "second_checklist_answer", - "value": "second_checklist_answer" + "value": "second_checklist_answer", }, ], } checklist_index = { - "required": - False, - "instructions": - "checklist_index", - "name": - "checklist_index", - "type": - "checklist", - "scope": - "index", + "required": False, + "instructions": "checklist_index", + "name": "checklist_index", + "type": "checklist", + "scope": "index", "options": [ { "label": "first_checklist_answer", - "value": "first_checklist_answer" + "value": "first_checklist_answer", }, { "label": "second_checklist_answer", - "value": "second_checklist_answer" + "value": "second_checklist_answer", }, ], } @@ -388,14 +329,10 @@ def normalized_ontology_by_media_type(): "options": [], } radio = { - "required": - False, - "instructions": - "radio", - "name": - "radio", - "type": - "radio", + "required": False, + "instructions": "radio", + "name": "radio", + "type": "radio", "options": [ { "label": "first_radio_answer", @@ -418,39 +355,45 @@ def normalized_ontology_by_media_type(): "maxCharacters": 50, "minCharacters": 1, "schemaNodeId": 
None, - "type": "prompt" + "type": "prompt", } response_radio = { "instructions": "radio-response", "name": "radio-response", - "options": [{ - "label": "first_radio_answer", - "value": "first_radio_answer", - "options": [] - }, { - "label": "second_radio_answer", - "value": "second_radio_answer", - "options": [] - }], + "options": [ + { + "label": "first_radio_answer", + "value": "first_radio_answer", + "options": [], + }, + { + "label": "second_radio_answer", + "value": "second_radio_answer", + "options": [], + }, + ], "required": True, - "type": "response-radio" + "type": "response-radio", } response_checklist = { "instructions": "checklist-response", "name": "checklist-response", - "options": [{ - "label": "first_checklist_answer", - "value": "first_checklist_answer", - "options": [] - }, { - "label": "second_checklist_answer", - "value": "second_checklist_answer", - "options": [] - }], + "options": [ + { + "label": "first_checklist_answer", + "value": "first_checklist_answer", + "options": [], + }, + { + "label": "second_checklist_answer", + "value": "second_checklist_answer", + "options": [], + }, + ], "required": True, - "type": "response-checklist" + "type": "response-checklist", } response_text = { @@ -459,7 +402,7 @@ def normalized_ontology_by_media_type(): "minCharacters": 1, "name": "response-text", "required": True, - "type": "response-text" + "type": "response-text", } return { @@ -476,7 +419,7 @@ def normalized_ontology_by_media_type(): checklist, free_form_text, radio, - ] + ], }, MediaType.Text: { "tools": [entity_tool], @@ -484,7 +427,7 @@ def normalized_ontology_by_media_type(): checklist, free_form_text, radio, - ] + ], }, MediaType.Video: { "tools": [ @@ -495,9 +438,12 @@ def normalized_ontology_by_media_type(): raster_segmentation_tool, ], "classifications": [ - checklist, free_form_text, radio, checklist_index, - free_form_text_index - ] + checklist, + free_form_text, + radio, + checklist_index, + free_form_text_index, + ], }, MediaType.Geospatial_Tile: { "tools": [ @@ -511,7 +457,7 @@ def normalized_ontology_by_media_type(): checklist, free_form_text, radio, - ] + ], }, MediaType.Document: { "tools": [entity_tool, bbox_tool, bbox_tool_with_nested_text], @@ -519,7 +465,7 @@ def normalized_ontology_by_media_type(): checklist, free_form_text, radio, - ] + ], }, MediaType.Audio: { "tools": [], @@ -527,7 +473,7 @@ def normalized_ontology_by_media_type(): checklist, free_form_text, radio, - ] + ], }, MediaType.Html: { "tools": [], @@ -535,34 +481,42 @@ def normalized_ontology_by_media_type(): checklist, free_form_text, radio, - ] + ], }, MediaType.Dicom: { "tools": [raster_segmentation_tool, polyline_tool], - "classifications": [] + "classifications": [], }, MediaType.Conversational: { "tools": [entity_tool], "classifications": [ - checklist, free_form_text, radio, checklist_index, - free_form_text_index - ] + checklist, + free_form_text, + radio, + checklist_index, + free_form_text_index, + ], }, MediaType.LLMPromptResponseCreation: { "tools": [], "classifications": [ - prompt_text, response_text, response_radio, response_checklist - ] + prompt_text, + response_text, + response_radio, + response_checklist, + ], }, MediaType.LLMPromptCreation: { "tools": [], - "classifications": [prompt_text] + "classifications": [prompt_text], }, OntologyKind.ResponseCreation: { "tools": [], "classifications": [ - response_text, response_radio, response_checklist - ] + response_text, + response_radio, + response_checklist, + ], }, "all": { "tools": [ @@ -581,8 +535,8 @@ def 
normalized_ontology_by_media_type(): free_form_text, free_form_text_index, radio, - ] - } + ], + }, } @@ -617,7 +571,7 @@ def func(project): @pytest.fixture def hardcoded_datarow_id(): - data_row_id = 'ck8q9q9qj00003g5z3q1q9q9q' + data_row_id = "ck8q9q9qj00003g5z3q1q9q9q" def get_data_row_id(): return data_row_id @@ -639,33 +593,40 @@ def get_global_key(): def _create_response_creation_project( - client: Client, rand_gen, data_row_json_by_media_type, ontology_kind, - normalized_ontology_by_media_type) -> Tuple[Project, Ontology, Dataset]: + client: Client, + rand_gen, + data_row_json_by_media_type, + ontology_kind, + normalized_ontology_by_media_type, +) -> Tuple[Project, Ontology, Dataset]: "For response creation projects" dataset = client.create_dataset(name=rand_gen(str)) project = client.create_response_creation_project( - name=f"{ontology_kind}-{rand_gen(str)}") + name=f"{ontology_kind}-{rand_gen(str)}" + ) ontology = client.create_ontology( name=f"{ontology_kind}-{rand_gen(str)}", normalized=normalized_ontology_by_media_type[ontology_kind], media_type=MediaType.Text, - ontology_kind=ontology_kind) + ontology_kind=ontology_kind, + ) project.connect_ontology(ontology) data_row_data = [] for _ in range(DATA_ROW_COUNT): - data_row_data.append(data_row_json_by_media_type[MediaType.Text]( - rand_gen(str))) + data_row_data.append( + data_row_json_by_media_type[MediaType.Text](rand_gen(str)) + ) task = dataset.create_data_rows(data_row_data) task.wait_till_done() - global_keys = [row['global_key'] for row in task.result] - data_row_ids = [row['id'] for row in task.result] + global_keys = [row["global_key"] for row in task.result] + data_row_ids = [row["id"] for row in task.result] project.create_batch( rand_gen(str), @@ -679,16 +640,15 @@ def _create_response_creation_project( @pytest.fixture -def llm_prompt_response_creation_dataset_with_data_row(client: Client, - rand_gen): +def llm_prompt_response_creation_dataset_with_data_row( + client: Client, rand_gen +): dataset = client.create_dataset(name=rand_gen(str)) global_key = str(uuid.uuid4()) convo_data = { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/conversational-sample-data/pairwise_shopping_2.json", - "global_key": - global_key + "row_data": "https://storage.googleapis.com/labelbox-datasets/conversational-sample-data/pairwise_shopping_2.json", + "global_key": global_key, } task = dataset.create_data_rows([convo_data]) @@ -700,26 +660,33 @@ def llm_prompt_response_creation_dataset_with_data_row(client: Client, def _create_prompt_response_project( - client: Client, rand_gen, media_type, normalized_ontology_by_media_type, - export_v2_test_helpers, llm_prompt_response_creation_dataset_with_data_row + client: Client, + rand_gen, + media_type, + normalized_ontology_by_media_type, + export_v2_test_helpers, + llm_prompt_response_creation_dataset_with_data_row, ) -> Tuple[Project, Ontology]: """For prompt response data row auto gen projects""" dataset = llm_prompt_response_creation_dataset_with_data_row prompt_response_project = client.create_prompt_response_generation_project( name=f"{media_type.value}-{rand_gen(str)}", dataset_id=dataset.uid, - media_type=media_type) + media_type=media_type, + ) ontology = client.create_ontology( name=f"{media_type}-{rand_gen(str)}", normalized=normalized_ontology_by_media_type[media_type], - media_type=media_type) + media_type=media_type, + ) prompt_response_project.connect_ontology(ontology) # We have to export to get data row ids result = 
export_v2_test_helpers.run_project_export_v2_task( - prompt_response_project) + prompt_response_project + ) data_row_ids = [dr["data_row"]["id"] for dr in result] global_keys = [dr["data_row"]["global_key"] for dr in result] @@ -731,32 +698,39 @@ def _create_prompt_response_project( def _create_project( - client: Client, rand_gen, data_row_json_by_media_type, media_type, - normalized_ontology_by_media_type) -> Tuple[Project, Ontology, Dataset]: - """ Shared function to configure project for integration tests """ + client: Client, + rand_gen, + data_row_json_by_media_type, + media_type, + normalized_ontology_by_media_type, +) -> Tuple[Project, Ontology, Dataset]: + """Shared function to configure project for integration tests""" dataset = client.create_dataset(name=rand_gen(str)) - project = client.create_project(name=f"{media_type}-{rand_gen(str)}", - media_type=media_type) + project = client.create_project( + name=f"{media_type}-{rand_gen(str)}", media_type=media_type + ) ontology = client.create_ontology( name=f"{media_type}-{rand_gen(str)}", normalized=normalized_ontology_by_media_type[media_type], - media_type=media_type) + media_type=media_type, + ) project.connect_ontology(ontology) data_row_data = [] for _ in range(DATA_ROW_COUNT): - data_row_data.append(data_row_json_by_media_type[media_type]( - rand_gen(str))) + data_row_data.append( + data_row_json_by_media_type[media_type](rand_gen(str)) + ) task = dataset.create_data_rows(data_row_data) task.wait_till_done() - global_keys = [row['global_key'] for row in task.result] - data_row_ids = [row['id'] for row in task.result] + global_keys = [row["global_key"] for row in task.result] + data_row_ids = [row["id"] for row in task.result] project.create_batch( rand_gen(str), @@ -770,29 +744,48 @@ def _create_project( @pytest.fixture -def configured_project(client: Client, rand_gen, data_row_json_by_media_type, - request: FixtureRequest, - normalized_ontology_by_media_type, - export_v2_test_helpers, - llm_prompt_response_creation_dataset_with_data_row): +def configured_project( + client: Client, + rand_gen, + data_row_json_by_media_type, + request: FixtureRequest, + normalized_ontology_by_media_type, + export_v2_test_helpers, + llm_prompt_response_creation_dataset_with_data_row, +): """Configure project for test. Request.param will contain the media type if not present will use Image MediaType. 
The project will have 10 data rows.""" media_type = getattr(request, "param", MediaType.Image) dataset = None - if media_type == MediaType.LLMPromptCreation or media_type == MediaType.LLMPromptResponseCreation: + if ( + media_type == MediaType.LLMPromptCreation + or media_type == MediaType.LLMPromptResponseCreation + ): project, ontology = _create_prompt_response_project( - client, rand_gen, media_type, normalized_ontology_by_media_type, + client, + rand_gen, + media_type, + normalized_ontology_by_media_type, export_v2_test_helpers, - llm_prompt_response_creation_dataset_with_data_row) + llm_prompt_response_creation_dataset_with_data_row, + ) elif media_type == OntologyKind.ResponseCreation: project, ontology, dataset = _create_response_creation_project( - client, rand_gen, data_row_json_by_media_type, media_type, - normalized_ontology_by_media_type) + client, + rand_gen, + data_row_json_by_media_type, + media_type, + normalized_ontology_by_media_type, + ) else: project, ontology, dataset = _create_project( - client, rand_gen, data_row_json_by_media_type, media_type, - normalized_ontology_by_media_type) + client, + rand_gen, + data_row_json_by_media_type, + media_type, + normalized_ontology_by_media_type, + ) yield project @@ -805,28 +798,46 @@ def configured_project(client: Client, rand_gen, data_row_json_by_media_type, @pytest.fixture() -def configured_project_by_global_key(client: Client, rand_gen, - data_row_json_by_media_type, - request: FixtureRequest, - normalized_ontology_by_media_type, - export_v2_test_helpers): +def configured_project_by_global_key( + client: Client, + rand_gen, + data_row_json_by_media_type, + request: FixtureRequest, + normalized_ontology_by_media_type, + export_v2_test_helpers, +): """Does the same thing as configured project but with global keys focus.""" media_type = getattr(request, "param", MediaType.Image) dataset = None - if media_type == MediaType.LLMPromptCreation or media_type == MediaType.LLMPromptResponseCreation: + if ( + media_type == MediaType.LLMPromptCreation + or media_type == MediaType.LLMPromptResponseCreation + ): project, ontology = _create_prompt_response_project( - client, rand_gen, media_type, normalized_ontology_by_media_type, - export_v2_test_helpers) + client, + rand_gen, + media_type, + normalized_ontology_by_media_type, + export_v2_test_helpers, + ) elif media_type == OntologyKind.ResponseCreation: project, ontology, dataset = _create_response_creation_project( - client, rand_gen, data_row_json_by_media_type, media_type, - normalized_ontology_by_media_type) + client, + rand_gen, + data_row_json_by_media_type, + media_type, + normalized_ontology_by_media_type, + ) else: project, ontology, dataset = _create_project( - client, rand_gen, data_row_json_by_media_type, media_type, - normalized_ontology_by_media_type) + client, + rand_gen, + data_row_json_by_media_type, + media_type, + normalized_ontology_by_media_type, + ) yield project @@ -839,25 +850,42 @@ def configured_project_by_global_key(client: Client, rand_gen, @pytest.fixture(scope="module") -def module_project(client: Client, rand_gen, data_row_json_by_media_type, - request: FixtureRequest, normalized_ontology_by_media_type): +def module_project( + client: Client, + rand_gen, + data_row_json_by_media_type, + request: FixtureRequest, + normalized_ontology_by_media_type, +): """Generates a image project that scopes to the test module(file). 
Used to reduce api calls.""" media_type = getattr(request, "param", MediaType.Image) media_type = getattr(request, "param", MediaType.Image) dataset = None - if media_type == MediaType.LLMPromptCreation or media_type == MediaType.LLMPromptResponseCreation: + if ( + media_type == MediaType.LLMPromptCreation + or media_type == MediaType.LLMPromptResponseCreation + ): project, ontology = _create_prompt_response_project( - client, rand_gen, media_type, normalized_ontology_by_media_type) + client, rand_gen, media_type, normalized_ontology_by_media_type + ) elif media_type == OntologyKind.ResponseCreation: project, ontology, dataset = _create_response_creation_project( - client, rand_gen, data_row_json_by_media_type, media_type, - normalized_ontology_by_media_type) + client, + rand_gen, + data_row_json_by_media_type, + media_type, + normalized_ontology_by_media_type, + ) else: project, ontology, dataset = _create_project( - client, rand_gen, data_row_json_by_media_type, media_type, - normalized_ontology_by_media_type) + client, + rand_gen, + data_row_json_by_media_type, + media_type, + normalized_ontology_by_media_type, + ) yield project @@ -872,17 +900,17 @@ def module_project(client: Client, rand_gen, data_row_json_by_media_type, @pytest.fixture def prediction_id_mapping(request, normalized_ontology_by_media_type): """Creates the base of annotation based on tools inside project ontology. We would want only annotations supported for the MediaType of the ontology and project. Annotations are generated for each data row created later be combined inside the test file. This serves as the base fixture for all the interference (annotations) fixture. This fixtures supports a few strategies: - + Integration test: configured_project: generates data rows with data row id focus. configured_project_by_global_key: generates data rows with global key focus. module_configured_project: configured project but scoped to test module. Unit tests - Individuals can supply hard-coded data row ids or global keys without configured a project must include a media type fixture to get the appropriate annotations. - - Each strategy provides a few items. - + Individuals can supply hard-coded data row ids or global keys without configured a project must include a media type fixture to get the appropriate annotations. + + Each strategy provides a few items. 
+ Labelbox Project (unit testing strategies do not make api calls so will have None for project) Data row identifiers (ids the annotation uses) Ontology: normalized ontology @@ -890,23 +918,23 @@ def prediction_id_mapping(request, normalized_ontology_by_media_type): if "configured_project" in request.fixturenames: project = request.getfixturevalue("configured_project") - data_row_identifiers = [{ - "id": data_row_id - } for data_row_id in project.data_row_ids] + data_row_identifiers = [ + {"id": data_row_id} for data_row_id in project.data_row_ids + ] ontology = project.ontology().normalized elif "configured_project_by_global_key" in request.fixturenames: project = request.getfixturevalue("configured_project_by_global_key") - data_row_identifiers = [{ - "globalKey": global_key - } for global_key in project.global_keys] + data_row_identifiers = [ + {"globalKey": global_key} for global_key in project.global_keys + ] ontology = project.ontology().normalized elif "module_project" in request.fixturenames: project = request.getfixturevalue("module_project") - data_row_identifiers = [{ - "id": data_row_id - } for data_row_id in project.data_row_ids] + data_row_identifiers = [ + {"id": data_row_id} for data_row_id in project.data_row_ids + ] ontology = project.ontology().normalized elif "hardcoded_datarow_id" in request.fixturenames: @@ -915,9 +943,9 @@ def prediction_id_mapping(request, normalized_ontology_by_media_type): project = None media_type = request.getfixturevalue("media_type") ontology = normalized_ontology_by_media_type[media_type] - data_row_identifiers = [{ - "id": request.getfixturevalue("hardcoded_datarow_id")() - }] + data_row_identifiers = [ + {"id": request.getfixturevalue("hardcoded_datarow_id")()} + ] elif "hardcoded_global_key" in request.fixturenames: if "media_type" not in request.fixturenames: @@ -925,9 +953,9 @@ def prediction_id_mapping(request, normalized_ontology_by_media_type): project = None media_type = request.getfixturevalue("media_type") ontology = normalized_ontology_by_media_type[media_type] - data_row_identifiers = [{ - "globalKey": request.getfixturevalue("hardcoded_global_key")() - }] + data_row_identifiers = [ + {"globalKey": request.getfixturevalue("hardcoded_global_key")()} + ] # Used for tests that need access to every ontology else: @@ -939,21 +967,25 @@ def prediction_id_mapping(request, normalized_ontology_by_media_type): base_annotations = [] for data_row_identifier in data_row_identifiers: base_annotation = {} - for feature in (ontology["tools"] + ontology["classifications"]): + for feature in ontology["tools"] + ontology["classifications"]: if "tool" in feature: - feature_type = (feature["tool"] if feature["classifications"] - == [] else f"{feature['tool']}_nested" - ) # tool vs nested classification tool + feature_type = ( + feature["tool"] + if feature["classifications"] == [] + else f"{feature['tool']}_nested" + ) # tool vs nested classification tool else: - feature_type = (feature["type"] if "scope" not in feature else - f"{feature['type']}_{feature['scope']}" - ) # checklist vs indexed checklist + feature_type = ( + feature["type"] + if "scope" not in feature + else f"{feature['type']}_{feature['scope']}" + ) # checklist vs indexed checklist base_annotation[feature_type] = { "uuid": str(uuid.uuid4()), "name": feature["name"], "tool": feature, - "dataRow": data_row_identifier + "dataRow": data_row_identifier, } base_annotations.append(base_annotation) @@ -968,26 +1000,16 @@ def polygon_inference(prediction_id_mapping): if "polygon" not in 
feature: continue polygon = feature["polygon"].copy() - polygon.update({ - "polygon": [ - { - "x": 147.692, - "y": 118.154 - }, - { - "x": 142.769, - "y": 104.923 - }, - { - "x": 57.846, - "y": 118.769 - }, - { - "x": 28.308, - "y": 169.846 - }, - ] - }) + polygon.update( + { + "polygon": [ + {"x": 147.692, "y": 118.154}, + {"x": 142.769, "y": 104.923}, + {"x": 57.846, "y": 118.769}, + {"x": 28.308, "y": 169.846}, + ] + } + ) del polygon["tool"] polygons.append(polygon) return polygons @@ -1000,14 +1022,11 @@ def rectangle_inference(prediction_id_mapping): if "rectangle" not in feature: continue rectangle = feature["rectangle"].copy() - rectangle.update({ - "bbox": { - "top": 48, - "left": 58, - "height": 65, - "width": 12 - }, - }) + rectangle.update( + { + "bbox": {"top": 48, "left": 58, "height": 65, "width": 12}, + } + ) del rectangle["tool"] rectangles.append(rectangle) return rectangles @@ -1020,34 +1039,35 @@ def rectangle_inference_with_confidence(prediction_id_mapping): if "rectangle_nested" not in feature: continue rectangle = feature["rectangle_nested"].copy() - rectangle.update({ - "bbox": { - "top": 48, - "left": 58, - "height": 65, - "width": 12 - }, - "classifications": [{ - "name": rectangle["tool"]["classifications"][0]["name"], - "answer": { - "name": - rectangle["tool"]["classifications"][0]["options"][0] - ["value"], - "classifications": [{ - "name": - rectangle["tool"]["classifications"][0]["options"] - [0]["options"][1]["name"], - "answer": - "nested answer", - }], - }, - }], - }) + rectangle.update( + { + "bbox": {"top": 48, "left": 58, "height": 65, "width": 12}, + "classifications": [ + { + "name": rectangle["tool"]["classifications"][0]["name"], + "answer": { + "name": rectangle["tool"]["classifications"][0][ + "options" + ][0]["value"], + "classifications": [ + { + "name": rectangle["tool"][ + "classifications" + ][0]["options"][0]["options"][1]["name"], + "answer": "nested answer", + } + ], + }, + } + ], + } + ) rectangle.update({"confidence": 0.9}) rectangle["classifications"][0]["answer"]["confidence"] = 0.8 rectangle["classifications"][0]["answer"]["classifications"][0][ - "confidence"] = 0.7 + "confidence" + ] = 0.7 del rectangle["tool"] rectangles.append(rectangle) @@ -1071,15 +1091,14 @@ def line_inference(prediction_id_mapping): if "line" not in feature: continue line = feature["line"].copy() - line.update({ - "line": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 150.692, - "y": 160.154 - }] - }) + line.update( + { + "line": [ + {"x": 147.692, "y": 118.154}, + {"x": 150.692, "y": 160.154}, + ] + } + ) del line["tool"] lines.append(line) return lines @@ -1093,24 +1112,20 @@ def line_inference_v2(prediction_id_mapping): continue line = feature["line"].copy() line_data = { - "groupKey": - "axial", - "segments": [{ - "keyframes": [{ - "frame": - 1, - "line": [ - { - "x": 147.692, - "y": 118.154 - }, + "groupKey": "axial", + "segments": [ + { + "keyframes": [ { - "x": 150.692, - "y": 160.154 - }, - ], - }] - },], + "frame": 1, + "line": [ + {"x": 147.692, "y": 118.154}, + {"x": 150.692, "y": 160.154}, + ], + } + ] + }, + ], } line.update(line_data) del line["tool"] @@ -1151,13 +1166,12 @@ def entity_inference_index(prediction_id_mapping): if "named-entity" not in feature: continue entity = feature["named-entity"].copy() - entity.update({ - "location": { - "start": 0, - "end": 8 - }, - "messageId": "0", - }) + entity.update( + { + "location": {"start": 0, "end": 8}, + "messageId": "0", + } + ) del entity["tool"] named_entities.append(entity) return 
named_entities @@ -1171,20 +1185,22 @@ def entity_inference_document(prediction_id_mapping): continue entity = feature["named-entity"].copy() document_selections = { - "textSelections": [{ - "tokenIds": [ - "3f984bf3-1d61-44f5-b59a-9658a2e3440f", - "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", - "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", - "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", - "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", - "67c7c19e-4654-425d-bf17-2adb8cf02c30", - "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", - "b0e94071-2187-461e-8e76-96c58738a52c", - ], - "groupId": "2f4336f4-a07e-4e0a-a9e1-5629b03b719b", - "page": 1, - }] + "textSelections": [ + { + "tokenIds": [ + "3f984bf3-1d61-44f5-b59a-9658a2e3440f", + "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", + "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", + "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", + "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", + "67c7c19e-4654-425d-bf17-2adb8cf02c30", + "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", + "b0e94071-2187-461e-8e76-96c58738a52c", + ], + "groupId": "2f4336f4-a07e-4e0a-a9e1-5629b03b719b", + "page": 1, + } + ] } entity.update(document_selections) del entity["tool"] @@ -1199,13 +1215,14 @@ def segmentation_inference(prediction_id_mapping): if "superpixel" not in feature: continue segmentation = feature["superpixel"].copy() - segmentation.update({ - "mask": { - "instanceURI": - "https://storage.googleapis.com/labelbox-datasets/image_sample_data/raster_seg.png", - "colorRGB": (255, 255, 255), + segmentation.update( + { + "mask": { + "instanceURI": "https://storage.googleapis.com/labelbox-datasets/image_sample_data/raster_seg.png", + "colorRGB": (255, 255, 255), + } } - }) + ) del segmentation["tool"] superpixel_masks.append(segmentation) return superpixel_masks @@ -1218,13 +1235,12 @@ def segmentation_inference_rle(prediction_id_mapping): if "superpixel" not in feature: continue segmentation = feature["superpixel"].copy() - segmentation.update({ - "uuid": str(uuid.uuid4()), - "mask": { - "size": [10, 10], - "counts": [1, 0, 10, 100] - }, - }) + segmentation.update( + { + "uuid": str(uuid.uuid4()), + "mask": {"size": [10, 10], "counts": [1, 0, 10, 100]}, + } + ) del segmentation["tool"] superpixel_masks.append(segmentation) return superpixel_masks @@ -1237,12 +1253,14 @@ def segmentation_inference_png(prediction_id_mapping): if "superpixel" not in feature: continue segmentation = feature["superpixel"].copy() - segmentation.update({ - "uuid": str(uuid.uuid4()), - "mask": { - "png": "somedata", - }, - }) + segmentation.update( + { + "uuid": str(uuid.uuid4()), + "mask": { + "png": "somedata", + }, + } + ) del segmentation["tool"] superpixel_masks.append(segmentation) return superpixel_masks @@ -1255,13 +1273,14 @@ def checklist_inference(prediction_id_mapping): if "checklist" not in feature: continue checklist = feature["checklist"].copy() - checklist.update({ - "answers": [{ - "name": "first_checklist_answer" - }, { - "name": "second_checklist_answer" - }] - }) + checklist.update( + { + "answers": [ + {"name": "first_checklist_answer"}, + {"name": "second_checklist_answer"}, + ] + } + ) del checklist["tool"] checklists.append(checklist) return checklists @@ -1274,14 +1293,15 @@ def checklist_inference_index(prediction_id_mapping): if "checklist_index" not in feature: return None checklist = feature["checklist_index"].copy() - checklist.update({ - "answers": [{ - "name": "first_checklist_answer" - }, { - "name": "second_checklist_answer" - }], - "messageId": "0", - }) + checklist.update( + { + "answers": [ + {"name": "first_checklist_answer"}, + 
{"name": "second_checklist_answer"}, + ], + "messageId": "0", + } + ) del checklist["tool"] checklists.append(checklist) return checklists @@ -1307,11 +1327,11 @@ def radio_response_inference(prediction_id_mapping): if "response-radio" not in feature: continue response_radio = feature["response-radio"].copy() - response_radio.update({ - "answer": { - "name": "first_radio_answer" - }, - }) + response_radio.update( + { + "answer": {"name": "first_radio_answer"}, + } + ) del response_radio["tool"] response_radios.append(response_radio) return response_radios @@ -1324,13 +1344,14 @@ def checklist_response_inference(prediction_id_mapping): if "response-checklist" not in feature: continue response_checklist = feature["response-checklist"].copy() - response_checklist.update({ - "answer": [{ - "name": "first_checklist_answer" - }, { - "name": "second_checklist_answer" - }] - }) + response_checklist.update( + { + "answer": [ + {"name": "first_checklist_answer"}, + {"name": "second_checklist_answer"}, + ] + } + ) del response_checklist["tool"] response_checklists.append(response_checklist) return response_checklists @@ -1392,25 +1413,29 @@ def video_checklist_inference(prediction_id_mapping): if "checklist" not in feature: continue checklist = feature["checklist"].copy() - checklist.update({ - "answers": [{ - "name": "first_checklist_answer" - }, { - "name": "second_checklist_answer" - }] - }) + checklist.update( + { + "answers": [ + {"name": "first_checklist_answer"}, + {"name": "second_checklist_answer"}, + ] + } + ) checklist.update( - {"frames": [ - { - "start": 7, - "end": 13, - }, - { - "start": 18, - "end": 19, - }, - ]}) + { + "frames": [ + { + "start": 7, + "end": 13, + }, + { + "start": 18, + "end": 19, + }, + ] + } + ) del checklist["tool"] checklists.append(checklist) return checklists @@ -1418,13 +1443,24 @@ def video_checklist_inference(prediction_id_mapping): @pytest.fixture def annotations_by_media_type( - polygon_inference, rectangle_inference, rectangle_inference_document, - line_inference_v2, line_inference, entity_inference, - entity_inference_index, entity_inference_document, - checklist_inference_index, text_inference_index, checklist_inference, - text_inference, video_checklist_inference, prompt_text_inference, - checklist_response_inference, radio_response_inference, - text_response_inference): + polygon_inference, + rectangle_inference, + rectangle_inference_document, + line_inference_v2, + line_inference, + entity_inference, + entity_inference_index, + entity_inference_document, + checklist_inference_index, + text_inference_index, + checklist_inference, + text_inference, + video_checklist_inference, + prompt_text_inference, + checklist_response_inference, + radio_response_inference, + text_response_inference, +): return { MediaType.Audio: [checklist_inference, text_inference], MediaType.Conversational: [ @@ -1450,22 +1486,26 @@ def annotations_by_media_type( MediaType.Text: [checklist_inference, text_inference, entity_inference], MediaType.Video: [video_checklist_inference], MediaType.LLMPromptResponseCreation: [ - prompt_text_inference, text_response_inference, - checklist_response_inference, radio_response_inference + prompt_text_inference, + text_response_inference, + checklist_response_inference, + radio_response_inference, ], MediaType.LLMPromptCreation: [prompt_text_inference], OntologyKind.ResponseCreation: [ - text_response_inference, checklist_response_inference, - radio_response_inference - ] + text_response_inference, + checklist_response_inference, + 
radio_response_inference, + ], } @pytest.fixture -def model_run_predictions(polygon_inference, rectangle_inference, - line_inference): +def model_run_predictions( + polygon_inference, rectangle_inference, line_inference +): # Not supporting mask since there isn't a signed url representing a seg mask to upload - return (polygon_inference + rectangle_inference + line_inference) + return polygon_inference + rectangle_inference + line_inference @pytest.fixture @@ -1476,17 +1516,28 @@ def object_predictions( entity_inference, segmentation_inference, ): - return (polygon_inference + rectangle_inference + line_inference + - entity_inference + segmentation_inference) + return ( + polygon_inference + + rectangle_inference + + line_inference + + entity_inference + + segmentation_inference + ) @pytest.fixture -def object_predictions_for_annotation_import(polygon_inference, - rectangle_inference, - line_inference, - segmentation_inference): - return (polygon_inference + rectangle_inference + line_inference + - segmentation_inference) +def object_predictions_for_annotation_import( + polygon_inference, + rectangle_inference, + line_inference, + segmentation_inference, +): + return ( + polygon_inference + + rectangle_inference + + line_inference + + segmentation_inference + ) @pytest.fixture @@ -1561,8 +1612,9 @@ def model_run_with_data_rows( model_run_predictions, ) upload_task.wait_until_done() - assert (upload_task.state == AnnotationImportState.FINISHED - ), "Label Import did not finish" + assert ( + upload_task.state == AnnotationImportState.FINISHED + ), "Label Import did not finish" assert ( len(upload_task.errors) == 0 ), f"Label Import {upload_task.name} failed with errors {upload_task.errors}" @@ -1574,12 +1626,16 @@ def model_run_with_data_rows( @pytest.fixture -def model_run_with_all_project_labels(client, configured_project, - model_run_predictions, - model_run: ModelRun, - wait_for_label_processing): +def model_run_with_all_project_labels( + client, + configured_project, + model_run_predictions, + model_run: ModelRun, + wait_for_label_processing, +): use_data_row_ids = list( - set([p["dataRow"]["id"] for p in model_run_predictions])) + set([p["dataRow"]["id"] for p in model_run_predictions]) + ) model_run.upsert_data_rows(use_data_row_ids) @@ -1590,8 +1646,9 @@ def model_run_with_all_project_labels(client, configured_project, model_run_predictions, ) upload_task.wait_until_done() - assert (upload_task.state == AnnotationImportState.FINISHED - ), "Label Import did not finish" + assert ( + upload_task.state == AnnotationImportState.FINISHED + ), "Label Import did not finish" assert ( len(upload_task.errors) == 0 ), f"Label Import {upload_task.name} failed with errors {upload_task.errors}" @@ -1603,7 +1660,6 @@ def model_run_with_all_project_labels(client, configured_project, class AnnotationImportTestHelpers: - @classmethod def assert_file_content(cls, url: str, predictions): response = requests.get(url) @@ -1644,34 +1700,16 @@ def expected_export_v2_image(): exported_annotations = { "objects": [ { - "name": - "polygon", - "value": - "polygon", - "annotation_kind": - "ImagePolygon", + "name": "polygon", + "value": "polygon", + "annotation_kind": "ImagePolygon", "classifications": [], "polygon": [ - { - "x": 147.692, - "y": 118.154 - }, - { - "x": 142.769, - "y": 104.923 - }, - { - "x": 57.846, - "y": 118.769 - }, - { - "x": 28.308, - "y": 169.846 - }, - { - "x": 147.692, - "y": 118.154 - }, + {"x": 147.692, "y": 118.154}, + {"x": 142.769, "y": 104.923}, + {"x": 57.846, "y": 118.769}, + {"x": 
28.308, "y": 169.846}, + {"x": 147.692, "y": 118.154}, ], }, { @@ -1687,44 +1725,37 @@ def expected_export_v2_image(): }, }, { - "name": - "polyline", - "value": - "polyline", - "annotation_kind": - "ImagePolyline", + "name": "polyline", + "value": "polyline", + "annotation_kind": "ImagePolyline", "classifications": [], - "line": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 150.692, - "y": 160.154 - }], + "line": [ + {"x": 147.692, "y": 118.154}, + {"x": 150.692, "y": 160.154}, + ], }, ], "classifications": [ { - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], + "name": "checklist", + "value": "checklist", + "checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], }, { "name": "text", "value": "text", - "text_answer": { - "content": "free form text..." - }, + "text_answer": {"content": "free form text..."}, }, ], "relationships": [], @@ -1738,30 +1769,29 @@ def expected_export_v2_audio(): expected_annotations = { "classifications": [ { - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], + "name": "checklist", + "value": "checklist", + "checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], }, { "name": "text", "value": "text", - "text_answer": { - "content": "free form text..." - }, + "text_answer": {"content": "free form text..."}, }, ], "segments": {}, - "timestamp": {} + "timestamp": {}, } return expected_annotations @@ -1774,24 +1804,23 @@ def expected_export_v2_html(): { "name": "text", "value": "text", - "text_answer": { - "content": "free form text..." 
- }, + "text_answer": {"content": "free form text..."}, }, { - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], + "name": "checklist", + "value": "checklist", + "checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], }, ], "relationships": [], @@ -1802,39 +1831,40 @@ def expected_export_v2_html(): @pytest.fixture() def expected_export_v2_text(): expected_annotations = { - "objects": [{ - "name": "named-entity", - "value": "named_entity", - "annotation_kind": "TextEntity", - "classifications": [], - 'location': { - 'start': 112, - 'end': 128, - 'token': "research suggests" - }, - }], + "objects": [ + { + "name": "named-entity", + "value": "named_entity", + "annotation_kind": "TextEntity", + "classifications": [], + "location": { + "start": 112, + "end": 128, + "token": "research suggests", + }, + } + ], "classifications": [ { - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], + "name": "checklist", + "value": "checklist", + "checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], }, { "name": "text", "value": "text", - "text_answer": { - "content": "free form text..." 
- }, + "text_answer": {"content": "free form text..."}, }, ], "relationships": [], @@ -1846,25 +1876,26 @@ def expected_export_v2_text(): def expected_export_v2_video(): expected_annotations = { "frames": {}, - "segments": { - "": [[7, 13], [18, 19]] - }, + "segments": {"": [[7, 13], [18, 19]]}, "key_frame_feature_map": {}, - "classifications": [{ - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], - }], + "classifications": [ + { + "name": "checklist", + "value": "checklist", + "checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], + } + ], } return expected_annotations @@ -1872,44 +1903,41 @@ def expected_export_v2_video(): @pytest.fixture() def expected_export_v2_conversation(): expected_annotations = { - "objects": [{ - "name": "named-entity", - "value": "named_entity", - "annotation_kind": "ConversationalTextEntity", - "classifications": [], - "conversational_location": { - "message_id": "0", - "location": { - "start": 0, - "end": 8 + "objects": [ + { + "name": "named-entity", + "value": "named_entity", + "annotation_kind": "ConversationalTextEntity", + "classifications": [], + "conversational_location": { + "message_id": "0", + "location": {"start": 0, "end": 8}, }, - }, - }], + } + ], "classifications": [ { - "name": - "checklist_index", - "value": - "checklist_index", - "message_id": - "0", - "conversational_checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], + "name": "checklist_index", + "value": "checklist_index", + "message_id": "0", + "conversational_checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], }, { "name": "text_index", "value": "text_index", "message_id": "0", - "conversational_text_answer": { - "content": "free form text..." 
- }, + "conversational_text_answer": {"content": "free form text..."}, }, ], "relationships": [], @@ -1928,22 +1956,13 @@ def expected_export_v2_dicom(): "1": { "objects": { "": { - "name": - "polyline", - "value": - "polyline", - "annotation_kind": - "DICOMPolyline", + "name": "polyline", + "value": "polyline", + "annotation_kind": "DICOMPolyline", "classifications": [], "line": [ - { - "x": 147.692, - "y": 118.154 - }, - { - "x": 150.692, - "y": 160.154 - }, + {"x": 147.692, "y": 118.154}, + {"x": 150.692, "y": 160.154}, ], } }, @@ -1954,30 +1973,18 @@ def expected_export_v2_dicom(): "Sagittal": { "name": "Sagittal", "classifications": [], - "frames": {} - }, - "Coronal": { - "name": "Coronal", - "classifications": [], - "frames": {} + "frames": {}, }, + "Coronal": {"name": "Coronal", "classifications": [], "frames": {}}, }, "segments": { - "Axial": { - "": [[1, 1]] - }, + "Axial": {"": [[1, 1]]}, "Sagittal": {}, - "Coronal": {} + "Coronal": {}, }, "classifications": [], "key_frame_feature_map": { - "": { - "Axial": { - "1": True - }, - "Coronal": {}, - "Sagittal": {} - } + "": {"Axial": {"1": True}, "Coronal": {}, "Sagittal": {}} }, } return expected_annotations @@ -1993,24 +2000,23 @@ def expected_export_v2_document(): "annotation_kind": "DocumentEntityToken", "classifications": [], "location": { - "groups": [{ - "id": - "2f4336f4-a07e-4e0a-a9e1-5629b03b719b", - "page_number": - 1, - "tokens": [ - "3f984bf3-1d61-44f5-b59a-9658a2e3440f", - "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", - "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", - "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", - "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", - "67c7c19e-4654-425d-bf17-2adb8cf02c30", - "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", - "b0e94071-2187-461e-8e76-96c58738a52c", - ], - "text": - "Metal-insulator (MI) transitions have been one of the", - }] + "groups": [ + { + "id": "2f4336f4-a07e-4e0a-a9e1-5629b03b719b", + "page_number": 1, + "tokens": [ + "3f984bf3-1d61-44f5-b59a-9658a2e3440f", + "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", + "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", + "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", + "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", + "67c7c19e-4654-425d-bf17-2adb8cf02c30", + "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", + "b0e94071-2187-461e-8e76-96c58738a52c", + ], + "text": "Metal-insulator (MI) transitions have been one of the", + } + ] }, }, { @@ -2029,26 +2035,25 @@ def expected_export_v2_document(): ], "classifications": [ { - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], + "name": "checklist", + "value": "checklist", + "checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], }, { "name": "text", "value": "text", - "text_answer": { - "content": "free form text..." - }, + "text_answer": {"content": "free form text..."}, }, ], "relationships": [], @@ -2064,39 +2069,38 @@ def expected_export_v2_llm_prompt_response_creation(): { "name": "prompt-text", "value": "prompt-text", - "text_answer": { - "content": "free form text..." - }, + "text_answer": {"content": "free form text..."}, }, { - 'name': 'response-text', - 'text_answer': { - 'content': 'free form text...' 
- }, - 'value': 'response-text' + "name": "response-text", + "text_answer": {"content": "free form text..."}, + "value": "response-text", }, { - 'checklist_answers': [{ - 'classifications': [], - 'name': 'first_checklist_answer', - 'value': 'first_checklist_answer' - }, { - 'classifications': [], - 'name': 'second_checklist_answer', - 'value': 'second_checklist_answer' - }], - 'name': 'checklist-response', - 'value': 'checklist-response' + "checklist_answers": [ + { + "classifications": [], + "name": "first_checklist_answer", + "value": "first_checklist_answer", + }, + { + "classifications": [], + "name": "second_checklist_answer", + "value": "second_checklist_answer", + }, + ], + "name": "checklist-response", + "value": "checklist-response", }, { - 'name': 'radio-response', - 'radio_answer': { - 'classifications': [], - 'name': 'first_radio_answer', - 'value': 'first_radio_answer' + "name": "radio-response", + "radio_answer": { + "classifications": [], + "name": "first_radio_answer", + "value": "first_radio_answer", }, - 'name': 'radio-response', - 'value': 'radio-response' + "name": "radio-response", + "value": "radio-response", }, ], "relationships": [], @@ -2108,13 +2112,13 @@ def expected_export_v2_llm_prompt_response_creation(): def expected_export_v2_llm_prompt_creation(): expected_annotations = { "objects": [], - "classifications": [{ - "name": "prompt-text", - "value": "prompt-text", - "text_answer": { - "content": "free form text..." + "classifications": [ + { + "name": "prompt-text", + "value": "prompt-text", + "text_answer": {"content": "free form text..."}, }, - },], + ], "relationships": [], } return expected_annotations @@ -2123,38 +2127,39 @@ def expected_export_v2_llm_prompt_creation(): @pytest.fixture() def expected_export_v2_llm_response_creation(): expected_annotations = { - 'objects': [], - 'relationships': [], + "objects": [], + "relationships": [], "classifications": [ { - 'name': 'response-text', - 'text_answer': { - 'content': 'free form text...' 
- }, - 'value': 'response-text' + "name": "response-text", + "text_answer": {"content": "free form text..."}, + "value": "response-text", }, { - 'checklist_answers': [{ - 'classifications': [], - 'name': 'first_checklist_answer', - 'value': 'first_checklist_answer' - }, { - 'classifications': [], - 'name': 'second_checklist_answer', - 'value': 'second_checklist_answer' - }], - 'name': 'checklist-response', - 'value': 'checklist-response' + "checklist_answers": [ + { + "classifications": [], + "name": "first_checklist_answer", + "value": "first_checklist_answer", + }, + { + "classifications": [], + "name": "second_checklist_answer", + "value": "second_checklist_answer", + }, + ], + "name": "checklist-response", + "value": "checklist-response", }, { - 'name': 'radio-response', - 'radio_answer': { - 'classifications': [], - 'name': 'first_radio_answer', - 'value': 'first_radio_answer' + "name": "radio-response", + "radio_answer": { + "classifications": [], + "name": "first_radio_answer", + "value": "first_radio_answer", }, - 'name': 'radio-response', - 'value': 'radio-response' + "name": "radio-response", + "value": "radio-response", }, ], } @@ -2162,43 +2167,35 @@ def expected_export_v2_llm_response_creation(): @pytest.fixture -def exports_v2_by_media_type(expected_export_v2_image, expected_export_v2_audio, - expected_export_v2_html, expected_export_v2_text, - expected_export_v2_video, - expected_export_v2_conversation, - expected_export_v2_dicom, - expected_export_v2_document, - expected_export_v2_llm_prompt_response_creation, - expected_export_v2_llm_prompt_creation, - expected_export_v2_llm_response_creation): +def exports_v2_by_media_type( + expected_export_v2_image, + expected_export_v2_audio, + expected_export_v2_html, + expected_export_v2_text, + expected_export_v2_video, + expected_export_v2_conversation, + expected_export_v2_dicom, + expected_export_v2_document, + expected_export_v2_llm_prompt_response_creation, + expected_export_v2_llm_prompt_creation, + expected_export_v2_llm_response_creation, +): return { - MediaType.Image: - expected_export_v2_image, - MediaType.Audio: - expected_export_v2_audio, - MediaType.Html: - expected_export_v2_html, - MediaType.Text: - expected_export_v2_text, - MediaType.Video: - expected_export_v2_video, - MediaType.Conversational: - expected_export_v2_conversation, - MediaType.Dicom: - expected_export_v2_dicom, - MediaType.Document: - expected_export_v2_document, - MediaType.LLMPromptResponseCreation: - expected_export_v2_llm_prompt_response_creation, - MediaType.LLMPromptCreation: - expected_export_v2_llm_prompt_creation, - OntologyKind.ResponseCreation: - expected_export_v2_llm_response_creation + MediaType.Image: expected_export_v2_image, + MediaType.Audio: expected_export_v2_audio, + MediaType.Html: expected_export_v2_html, + MediaType.Text: expected_export_v2_text, + MediaType.Video: expected_export_v2_video, + MediaType.Conversational: expected_export_v2_conversation, + MediaType.Dicom: expected_export_v2_dicom, + MediaType.Document: expected_export_v2_document, + MediaType.LLMPromptResponseCreation: expected_export_v2_llm_prompt_response_creation, + MediaType.LLMPromptCreation: expected_export_v2_llm_prompt_creation, + OntologyKind.ResponseCreation: expected_export_v2_llm_response_creation, } class Helpers: - @staticmethod def remove_keys_recursive(d, keys): for k in keys: @@ -2230,7 +2227,6 @@ def rename_cuid_key_recursive(d): @staticmethod def set_project_media_type_from_data_type(project, data_type_class): - def to_pascal_case(name: str) -> 
str: return "".join([word.capitalize() for word in name.split("_")]) @@ -2250,7 +2246,7 @@ def to_pascal_case(name: str) -> str: @staticmethod def find_data_row_filter(data_row): - return lambda dr: dr['data_row']['id'] == data_row.uid + return lambda dr: dr["data_row"]["id"] == data_row.uid @pytest.fixture diff --git a/libs/labelbox/tests/data/annotation_import/test_annotation_import_limit.py b/libs/labelbox/tests/data/annotation_import/test_annotation_import_limit.py index 297f45c52..dec20fbb5 100644 --- a/libs/labelbox/tests/data/annotation_import/test_annotation_import_limit.py +++ b/libs/labelbox/tests/data/annotation_import/test_annotation_import_limit.py @@ -1,33 +1,56 @@ import itertools import uuid -from labelbox.schema.annotation_import import AnnotationImport, MALPredictionImport +from labelbox.schema.annotation_import import ( + AnnotationImport, + MALPredictionImport, +) from labelbox.schema.media_type import MediaType import pytest from unittest.mock import patch -@patch('labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT', 1) -def test_above_annotation_limit_on_single_import_on_single_data_row(annotations_by_media_type): - - annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[MediaType.Image])) +@patch("labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT", 1) +def test_above_annotation_limit_on_single_import_on_single_data_row( + annotations_by_media_type, +): + annotations_ndjson = list( + itertools.chain.from_iterable( + annotations_by_media_type[MediaType.Image] + ) + ) data_row_id = annotations_ndjson[0]["dataRow"]["id"] - data_row_annotations = [annotation for annotation in annotations_ndjson if annotation["dataRow"]["id"] == data_row_id and "bbox" in annotation] - - with pytest.raises(ValueError): - AnnotationImport._validate_data_rows([data_row_annotations[0]]*2) + data_row_annotations = [ + annotation + for annotation in annotations_ndjson + if annotation["dataRow"]["id"] == data_row_id and "bbox" in annotation + ] + with pytest.raises(ValueError): + AnnotationImport._validate_data_rows([data_row_annotations[0]] * 2) -@patch('labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT', 1) -def test_above_annotation_limit_divided_among_different_rows(annotations_by_media_type): - annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[MediaType.Image])) +@patch("labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT", 1) +def test_above_annotation_limit_divided_among_different_rows( + annotations_by_media_type, +): + annotations_ndjson = list( + itertools.chain.from_iterable( + annotations_by_media_type[MediaType.Image] + ) + ) data_row_id = annotations_ndjson[0]["dataRow"]["id"] - - first_data_row_annotation = [annotation for annotation in annotations_ndjson if annotation["dataRow"]["id"] == data_row_id and "bbox" in annotation][0] - + + first_data_row_annotation = [ + annotation + for annotation in annotations_ndjson + if annotation["dataRow"]["id"] == data_row_id and "bbox" in annotation + ][0] + second_data_row_annotation = first_data_row_annotation.copy() second_data_row_annotation["dataRow"]["id"] == "data_row_id_2" - + with pytest.raises(ValueError): - AnnotationImport._validate_data_rows([first_data_row_annotation, second_data_row_annotation]*2) \ No newline at end of file + AnnotationImport._validate_data_rows( + [first_data_row_annotation, second_data_row_annotation] * 2 + ) diff --git a/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py 
b/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py index 9e9abd47f..9abae1422 100644 --- a/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py +++ b/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py @@ -1,17 +1,30 @@ from unittest.mock import patch import uuid from labelbox import parser, Project -from labelbox.data.annotation_types.data.generic_data_row_data import GenericDataRowData +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) import pytest import random from labelbox.data.annotation_types.annotation import ObjectAnnotation -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnnotation, ClassificationAnswer, Radio +from labelbox.data.annotation_types.classification.classification import ( + Checklist, + ClassificationAnnotation, + ClassificationAnswer, + Radio, +) from labelbox.data.annotation_types.data.video import VideoData from labelbox.data.annotation_types.geometry.point import Point -from labelbox.data.annotation_types.geometry.rectangle import Rectangle, RectangleUnit +from labelbox.data.annotation_types.geometry.rectangle import ( + Rectangle, + RectangleUnit, +) from labelbox.data.annotation_types.label import Label from labelbox.data.annotation_types.data.text import TextData -from labelbox.data.annotation_types.ner import DocumentEntity, DocumentTextSelection +from labelbox.data.annotation_types.ner import ( + DocumentEntity, + DocumentTextSelection, +) from labelbox.data.annotation_types.video import VideoObjectAnnotation from labelbox.data.serialization import NDJsonConverter @@ -20,20 +33,22 @@ from labelbox.schema.enums import BulkImportRequestState from labelbox.schema.annotation_import import LabelImport, MALPredictionImport from labelbox.schema.media_type import MediaType + """ - Here we only want to check that the uploads are calling the validation - Then with unit tests we can check the types of errors raised """ -#TODO: remove library once bulk import requests are removed +# TODO: remove library once bulk import requests are removed + @pytest.mark.order(1) def test_create_from_url(module_project): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - bulk_import_request = module_project.upload_annotations(name=name, - annotations=url, - validate=False) + bulk_import_request = module_project.upload_annotations( + name=name, annotations=url, validate=False + ) assert bulk_import_request.project() == module_project assert bulk_import_request.name == name @@ -47,18 +62,20 @@ def test_validate_file(module_project): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" with pytest.raises(MALValidationError): - module_project.upload_annotations(name=name, - annotations=url, - validate=True) - #Schema ids shouldn't match + module_project.upload_annotations( + name=name, annotations=url, validate=True + ) + # Schema ids shouldn't match -def test_create_from_objects(module_project: Project, predictions, - annotation_import_test_helpers): +def test_create_from_objects( + module_project: Project, predictions, annotation_import_test_helpers +): name = str(uuid.uuid4()) bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions) + name=name, annotations=predictions + ) assert bulk_import_request.project() == module_project assert bulk_import_request.name == name @@ 
-66,16 +83,19 @@ def test_create_from_objects(module_project: Project, predictions, assert bulk_import_request.status_file_url is None assert bulk_import_request.state == BulkImportRequestState.RUNNING annotation_import_test_helpers.assert_file_content( - bulk_import_request.input_file_url, predictions) + bulk_import_request.input_file_url, predictions + ) -def test_create_from_label_objects(module_project, predictions, - annotation_import_test_helpers): +def test_create_from_label_objects( + module_project, predictions, annotation_import_test_helpers +): name = str(uuid.uuid4()) labels = list(NDJsonConverter.deserialize(predictions)) bulk_import_request = module_project.upload_annotations( - name=name, annotations=labels) + name=name, annotations=labels + ) assert bulk_import_request.project() == module_project assert bulk_import_request.name == name @@ -84,11 +104,13 @@ def test_create_from_label_objects(module_project, predictions, assert bulk_import_request.state == BulkImportRequestState.RUNNING normalized_predictions = list(NDJsonConverter.serialize(labels)) annotation_import_test_helpers.assert_file_content( - bulk_import_request.input_file_url, normalized_predictions) + bulk_import_request.input_file_url, normalized_predictions + ) -def test_create_from_local_file(tmp_path, predictions, module_project, - annotation_import_test_helpers): +def test_create_from_local_file( + tmp_path, predictions, module_project, annotation_import_test_helpers +): name = str(uuid.uuid4()) file_name = f"{name}.ndjson" file_path = tmp_path / file_name @@ -96,7 +118,8 @@ def test_create_from_local_file(tmp_path, predictions, module_project, parser.dump(predictions, f) bulk_import_request = module_project.upload_annotations( - name=name, annotations=str(file_path), validate=False) + name=name, annotations=str(file_path), validate=False + ) assert bulk_import_request.project() == module_project assert bulk_import_request.name == name @@ -104,18 +127,20 @@ def test_create_from_local_file(tmp_path, predictions, module_project, assert bulk_import_request.status_file_url is None assert bulk_import_request.state == BulkImportRequestState.RUNNING annotation_import_test_helpers.assert_file_content( - bulk_import_request.input_file_url, predictions) + bulk_import_request.input_file_url, predictions + ) def test_get(client, module_project): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - module_project.upload_annotations(name=name, - annotations=url, - validate=False) + module_project.upload_annotations( + name=name, annotations=url, validate=False + ) bulk_import_request = BulkImportRequest.from_name( - client, project_id=module_project.uid, name=name) + client, project_id=module_project.uid, name=name + ) assert bulk_import_request.project() == module_project assert bulk_import_request.name == name @@ -133,7 +158,8 @@ def test_validate_ndjson(tmp_path, module_project): with pytest.raises(ValueError): module_project.upload_annotations( - name="name", validate=True, annotations=str(file_path)) + name="name", validate=True, annotations=str(file_path) + ) def test_validate_ndjson_uuid(tmp_path, module_project, predictions): @@ -141,31 +167,34 @@ def test_validate_ndjson_uuid(tmp_path, module_project, predictions): file_path = tmp_path / file_name repeat_uuid = predictions.copy() uid = str(uuid.uuid4()) - repeat_uuid[0]['uuid'] = uid - repeat_uuid[1]['uuid'] = uid + repeat_uuid[0]["uuid"] = uid + repeat_uuid[1]["uuid"] = uid with file_path.open("w") 
as f: parser.dump(repeat_uuid, f) with pytest.raises(UuidError): - module_project.upload_annotations(name="name", - validate=True, - annotations=str(file_path)) + module_project.upload_annotations( + name="name", validate=True, annotations=str(file_path) + ) with pytest.raises(UuidError): - module_project.upload_annotations(name="name", - validate=True, - annotations=repeat_uuid) + module_project.upload_annotations( + name="name", validate=True, annotations=repeat_uuid + ) -@pytest.mark.skip("Slow test and uses a deprecated api endpoint for annotation imports") -def test_wait_till_done(rectangle_inference, - project): +@pytest.mark.skip( + "Slow test and uses a deprecated api endpoint for annotation imports" +) +def test_wait_till_done(rectangle_inference, project): name = str(uuid.uuid4()) url = project.client.upload_data( - content=parser.dumps(rectangle_inference), sign=True) + content=parser.dumps(rectangle_inference), sign=True + ) bulk_import_request = project.upload_annotations( - name=name, annotations=url, validate=False) + name=name, annotations=url, validate=False + ) assert len(bulk_import_request.inputs) == 1 bulk_import_request.wait_until_done() @@ -174,11 +203,12 @@ def test_wait_till_done(rectangle_inference, # Check that the status files are being returned as expected assert len(bulk_import_request.errors) == 0 assert len(bulk_import_request.inputs) == 1 - assert bulk_import_request.inputs[0]['uuid'] == rectangle_inference['uuid'] + assert bulk_import_request.inputs[0]["uuid"] == rectangle_inference["uuid"] assert len(bulk_import_request.statuses) == 1 - assert bulk_import_request.statuses[0]['status'] == 'SUCCESS' - assert bulk_import_request.statuses[0]['uuid'] == rectangle_inference[ - 'uuid'] + assert bulk_import_request.statuses[0]["status"] == "SUCCESS" + assert ( + bulk_import_request.statuses[0]["uuid"] == rectangle_inference["uuid"] + ) def test_project_bulk_import_requests(module_project, predictions): @@ -187,17 +217,20 @@ def test_project_bulk_import_requests(module_project, predictions): name = str(uuid.uuid4()) bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions) + name=name, annotations=predictions + ) bulk_import_request.wait_until_done() name = str(uuid.uuid4()) bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions) + name=name, annotations=predictions + ) bulk_import_request.wait_until_done() name = str(uuid.uuid4()) bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions) + name=name, annotations=predictions + ) bulk_import_request.wait_until_done() result = module_project.bulk_import_requests() @@ -206,12 +239,16 @@ def test_project_bulk_import_requests(module_project, predictions): def test_delete(module_project, predictions): name = str(uuid.uuid4()) - + bulk_import_requests = module_project.bulk_import_requests() - [bulk_import_request.delete() for bulk_import_request in bulk_import_requests] - + [ + bulk_import_request.delete() + for bulk_import_request in bulk_import_requests + ] + bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions) + name=name, annotations=predictions + ) bulk_import_request.wait_until_done() all_import_requests = module_project.bulk_import_requests() assert len(list(all_import_requests)) == 1 diff --git a/libs/labelbox/tests/data/annotation_import/test_data_types.py b/libs/labelbox/tests/data/annotation_import/test_data_types.py index d7b3ef825..1e45295ef 100644 --- 
a/libs/labelbox/tests/data/annotation_import/test_data_types.py +++ b/libs/labelbox/tests/data/annotation_import/test_data_types.py @@ -37,13 +37,15 @@ def test_data_row_type_by_data_row_id( annotations_by_media_type, hardcoded_datarow_id, ): - annotations_ndjson = annotations_by_media_type[media_type] + annotations_ndjson = annotations_by_media_type[media_type] annotations_ndjson = [annotation[0] for annotation in annotations_ndjson] - + label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] - - data_label = lb_types.Label(data=data_type_class(uid = hardcoded_datarow_id()), - annotations=label.annotations) + + data_label = lb_types.Label( + data=data_type_class(uid=hardcoded_datarow_id()), + annotations=label.annotations, + ) assert data_label.data.uid == label.data.uid assert label.annotations == data_label.annotations @@ -67,13 +69,15 @@ def test_data_row_type_by_global_key( annotations_by_media_type, hardcoded_global_key, ): - annotations_ndjson = annotations_by_media_type[media_type] + annotations_ndjson = annotations_by_media_type[media_type] annotations_ndjson = [annotation[0] for annotation in annotations_ndjson] - + label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] - - data_label = lb_types.Label(data=data_type_class(global_key = hardcoded_global_key()), - annotations=label.annotations) + + data_label = lb_types.Label( + data=data_type_class(global_key=hardcoded_global_key()), + annotations=label.annotations, + ) assert data_label.data.global_key == label.data.global_key assert label.annotations == data_label.annotations diff --git a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py index fa2c9e3f8..f8f0c449a 100644 --- a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py +++ b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py @@ -1,5 +1,7 @@ import datetime -from labelbox.data.annotation_types.data.generic_data_row_data import GenericDataRowData +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) from labelbox.data.serialization.ndjson.converter import NDJsonConverter from labelbox.data.annotation_types import Label import pytest @@ -10,6 +12,7 @@ from labelbox.schema.annotation_import import AnnotationImportState from labelbox import Project, Client, OntologyKind import itertools + """ - integration test for importing mal labels and ground truths with each supported MediaType. - NDJSON is used to generate annotations. 
@@ -18,7 +21,8 @@ def validate_iso_format(date_string: str): parsed_t = datetime.datetime.fromisoformat( - date_string) # this will blow up if the string is not in iso format + date_string + ) # this will blow up if the string is not in iso format assert parsed_t.hour is not None assert parsed_t.minute is not None assert parsed_t.second is not None @@ -26,16 +30,18 @@ def validate_iso_format(date_string: str): @pytest.mark.parametrize( "media_type, data_type_class", - [(MediaType.Audio, GenericDataRowData), - (MediaType.Html, GenericDataRowData), - (MediaType.Image, GenericDataRowData), - (MediaType.Text, GenericDataRowData), - (MediaType.Video, GenericDataRowData), - (MediaType.Conversational, GenericDataRowData), - (MediaType.Document, GenericDataRowData), - (MediaType.LLMPromptResponseCreation, GenericDataRowData), - (MediaType.LLMPromptCreation, GenericDataRowData), - (OntologyKind.ResponseCreation, GenericDataRowData)], + [ + (MediaType.Audio, GenericDataRowData), + (MediaType.Html, GenericDataRowData), + (MediaType.Image, GenericDataRowData), + (MediaType.Text, GenericDataRowData), + (MediaType.Video, GenericDataRowData), + (MediaType.Conversational, GenericDataRowData), + (MediaType.Document, GenericDataRowData), + (MediaType.LLMPromptResponseCreation, GenericDataRowData), + (MediaType.LLMPromptCreation, GenericDataRowData), + (OntologyKind.ResponseCreation, GenericDataRowData), + ], ) def test_generic_data_row_type_by_data_row_id( media_type, @@ -48,8 +54,10 @@ def test_generic_data_row_type_by_data_row_id( label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] - data_label = Label(data=data_type_class(uid=hardcoded_datarow_id()), - annotations=label.annotations) + data_label = Label( + data=data_type_class(uid=hardcoded_datarow_id()), + annotations=label.annotations, + ) assert data_label.data.uid == label.data.uid assert label.annotations == data_label.annotations @@ -67,7 +75,7 @@ def test_generic_data_row_type_by_data_row_id( (MediaType.Document, GenericDataRowData), # (MediaType.LLMPromptResponseCreation, GenericDataRowData), # (MediaType.LLMPromptCreation, GenericDataRowData), - (OntologyKind.ResponseCreation, GenericDataRowData) + (OntologyKind.ResponseCreation, GenericDataRowData), ], ) def test_generic_data_row_type_by_global_key( @@ -81,8 +89,10 @@ def test_generic_data_row_type_by_global_key( label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] - data_label = Label(data=data_type_class(global_key=hardcoded_global_key()), - annotations=label.annotations) + data_label = Label( + data=data_type_class(global_key=hardcoded_global_key()), + annotations=label.annotations, + ) assert data_label.data.global_key == label.data.global_key assert label.annotations == data_label.annotations @@ -90,16 +100,24 @@ def test_generic_data_row_type_by_global_key( @pytest.mark.parametrize( "configured_project, media_type", - [(MediaType.Audio, MediaType.Audio), (MediaType.Html, MediaType.Html), - (MediaType.Image, MediaType.Image), (MediaType.Text, MediaType.Text), - (MediaType.Video, MediaType.Video), - (MediaType.Conversational, MediaType.Conversational), - (MediaType.Document, MediaType.Document), - (MediaType.Dicom, MediaType.Dicom), - (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), - (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), - (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation)], - indirect=["configured_project"]) + [ + (MediaType.Audio, MediaType.Audio), + (MediaType.Html, MediaType.Html), + (MediaType.Image, 
MediaType.Image), + (MediaType.Text, MediaType.Text), + (MediaType.Video, MediaType.Video), + (MediaType.Conversational, MediaType.Conversational), + (MediaType.Document, MediaType.Document), + (MediaType.Dicom, MediaType.Dicom), + ( + MediaType.LLMPromptResponseCreation, + MediaType.LLMPromptResponseCreation, + ), + (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation), + ], + indirect=["configured_project"], +) def test_import_media_types( client: Client, configured_project: Project, @@ -110,18 +128,23 @@ def test_import_media_types( media_type, ): annotations_ndjson = list( - itertools.chain.from_iterable(annotations_by_media_type[media_type])) + itertools.chain.from_iterable(annotations_by_media_type[media_type]) + ) label_import = lb.LabelImport.create_from_objects( - client, configured_project.uid, f"test-import-{media_type}", - annotations_ndjson) + client, + configured_project.uid, + f"test-import-{media_type}", + annotations_ndjson, + ) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED assert len(label_import.errors) == 0 result = export_v2_test_helpers.run_project_export_v2_task( - configured_project) + configured_project + ) assert result @@ -129,20 +152,28 @@ def test_import_media_types( # timestamp fields are in iso format validate_iso_format(exported_data["data_row"]["details"]["created_at"]) validate_iso_format(exported_data["data_row"]["details"]["updated_at"]) - validate_iso_format(exported_data["projects"][configured_project.uid] - ["labels"][0]["label_details"]["created_at"]) - validate_iso_format(exported_data["projects"][configured_project.uid] - ["labels"][0]["label_details"]["updated_at"]) - - assert exported_data["data_row"][ - "id"] in configured_project.data_row_ids + validate_iso_format( + exported_data["projects"][configured_project.uid]["labels"][0][ + "label_details" + ]["created_at"] + ) + validate_iso_format( + exported_data["projects"][configured_project.uid]["labels"][0][ + "label_details" + ]["updated_at"] + ) + + assert ( + exported_data["data_row"]["id"] in configured_project.data_row_ids + ) exported_project = exported_data["projects"][configured_project.uid] exported_project_labels = exported_project["labels"][0] exported_annotations = exported_project_labels["annotations"] expected_data = exports_v2_by_media_type[media_type] - helpers.remove_keys_recursive(exported_annotations, - ["feature_id", "feature_schema_id"]) + helpers.remove_keys_recursive( + exported_annotations, ["feature_id", "feature_schema_id"] + ) helpers.rename_cuid_key_recursive(exported_annotations) assert exported_annotations == expected_data @@ -150,30 +181,46 @@ def test_import_media_types( @pytest.mark.parametrize( "configured_project_by_global_key, media_type", - [(MediaType.Audio, MediaType.Audio), (MediaType.Html, MediaType.Html), - (MediaType.Image, MediaType.Image), (MediaType.Text, MediaType.Text), - (MediaType.Video, MediaType.Video), - (MediaType.Conversational, MediaType.Conversational), - (MediaType.Document, MediaType.Document), - (MediaType.Dicom, MediaType.Dicom), - (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation)], - indirect=["configured_project_by_global_key"]) + [ + (MediaType.Audio, MediaType.Audio), + (MediaType.Html, MediaType.Html), + (MediaType.Image, MediaType.Image), + (MediaType.Text, MediaType.Text), + (MediaType.Video, MediaType.Video), + (MediaType.Conversational, MediaType.Conversational), + (MediaType.Document, 
MediaType.Document), + (MediaType.Dicom, MediaType.Dicom), + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation), + ], + indirect=["configured_project_by_global_key"], +) def test_import_media_types_by_global_key( - client, configured_project_by_global_key, annotations_by_media_type, - exports_v2_by_media_type, export_v2_test_helpers, helpers, media_type): + client, + configured_project_by_global_key, + annotations_by_media_type, + exports_v2_by_media_type, + export_v2_test_helpers, + helpers, + media_type, +): annotations_ndjson = list( - itertools.chain.from_iterable(annotations_by_media_type[media_type])) + itertools.chain.from_iterable(annotations_by_media_type[media_type]) + ) label_import = lb.LabelImport.create_from_objects( - client, configured_project_by_global_key.uid, - f"test-import-{media_type}", annotations_ndjson) + client, + configured_project_by_global_key.uid, + f"test-import-{media_type}", + annotations_ndjson, + ) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED assert len(label_import.errors) == 0 result = export_v2_test_helpers.run_project_export_v2_task( - configured_project_by_global_key) + configured_project_by_global_key + ) assert result @@ -182,22 +229,30 @@ def test_import_media_types_by_global_key( validate_iso_format(exported_data["data_row"]["details"]["created_at"]) validate_iso_format(exported_data["data_row"]["details"]["updated_at"]) validate_iso_format( - exported_data["projects"][configured_project_by_global_key.uid] - ["labels"][0]["label_details"]["created_at"]) + exported_data["projects"][configured_project_by_global_key.uid][ + "labels" + ][0]["label_details"]["created_at"] + ) validate_iso_format( - exported_data["projects"][configured_project_by_global_key.uid] - ["labels"][0]["label_details"]["updated_at"]) - - assert exported_data["data_row"][ - "id"] in configured_project_by_global_key.data_row_ids + exported_data["projects"][configured_project_by_global_key.uid][ + "labels" + ][0]["label_details"]["updated_at"] + ) + + assert ( + exported_data["data_row"]["id"] + in configured_project_by_global_key.data_row_ids + ) exported_project = exported_data["projects"][ - configured_project_by_global_key.uid] + configured_project_by_global_key.uid + ] exported_project_labels = exported_project["labels"][0] exported_annotations = exported_project_labels["annotations"] expected_data = exports_v2_by_media_type[media_type] - helpers.remove_keys_recursive(exported_annotations, - ["feature_id", "feature_schema_id"]) + helpers.remove_keys_recursive( + exported_annotations, ["feature_id", "feature_schema_id"] + ) helpers.rename_cuid_key_recursive(exported_annotations) assert exported_annotations == expected_data @@ -214,15 +269,21 @@ def test_import_media_types_by_global_key( (MediaType.Conversational, MediaType.Conversational), (MediaType.Document, MediaType.Document), (MediaType.Dicom, MediaType.Dicom), - (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), + ( + MediaType.LLMPromptResponseCreation, + MediaType.LLMPromptResponseCreation, + ), (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), - (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation) + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation), ], - indirect=["configured_project"]) -def test_import_mal_annotations(client, configured_project: Project, - annotations_by_media_type, media_type): + indirect=["configured_project"], +) +def test_import_mal_annotations( + client, configured_project: 
Project, annotations_by_media_type, media_type +): annotations_ndjson = list( - itertools.chain.from_iterable(annotations_by_media_type[media_type])) + itertools.chain.from_iterable(annotations_by_media_type[media_type]) + ) import_annotations = lb.MALPredictionImport.create_from_objects( client=client, @@ -238,20 +299,28 @@ def test_import_mal_annotations(client, configured_project: Project, @pytest.mark.parametrize( "configured_project_by_global_key, media_type", - [(MediaType.Audio, MediaType.Audio), (MediaType.Html, MediaType.Html), - (MediaType.Image, MediaType.Image), (MediaType.Text, MediaType.Text), - (MediaType.Video, MediaType.Video), - (MediaType.Conversational, MediaType.Conversational), - (MediaType.Document, MediaType.Document), - (MediaType.Dicom, MediaType.Dicom), - (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation)], - indirect=["configured_project_by_global_key"]) + [ + (MediaType.Audio, MediaType.Audio), + (MediaType.Html, MediaType.Html), + (MediaType.Image, MediaType.Image), + (MediaType.Text, MediaType.Text), + (MediaType.Video, MediaType.Video), + (MediaType.Conversational, MediaType.Conversational), + (MediaType.Document, MediaType.Document), + (MediaType.Dicom, MediaType.Dicom), + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation), + ], + indirect=["configured_project_by_global_key"], +) def test_import_mal_annotations_global_key( - client, configured_project_by_global_key: Project, - annotations_by_media_type, media_type): - + client, + configured_project_by_global_key: Project, + annotations_by_media_type, + media_type, +): annotations_ndjson = list( - itertools.chain.from_iterable(annotations_by_media_type[media_type])) + itertools.chain.from_iterable(annotations_by_media_type[media_type]) + ) import_annotations = lb.MALPredictionImport.create_from_objects( client=client, diff --git a/libs/labelbox/tests/data/annotation_import/test_label_import.py b/libs/labelbox/tests/data/annotation_import/test_label_import.py index 50b701813..5576025fd 100644 --- a/libs/labelbox/tests/data/annotation_import/test_label_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_label_import.py @@ -3,6 +3,7 @@ from labelbox import parser from labelbox.schema.annotation_import import AnnotationImportState, LabelImport + """ - Here we only want to check that the uploads are calling the validation - Then with unit tests we can check the types of errors raised @@ -10,50 +11,53 @@ """ -def test_create_with_url_arg(client, module_project, - annotation_import_test_helpers): +def test_create_with_url_arg( + client, module_project, annotation_import_test_helpers +): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" label_import = LabelImport.create( - client=client, - id=module_project.uid, - name=name, - url=url) + client=client, id=module_project.uid, name=name, url=url + ) assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name, url) -def test_create_from_url(client, module_project, - annotation_import_test_helpers): +def test_create_from_url( + client, module_project, annotation_import_test_helpers +): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" label_import = LabelImport.create_from_url( - client=client, - project_id=module_project.uid, - name=name, - url=url) + client=client, project_id=module_project.uid, name=name, url=url + ) assert 
label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name, url) -def test_create_with_labels_arg(client, module_project, object_predictions, - annotation_import_test_helpers): +def test_create_with_labels_arg( + client, module_project, object_predictions, annotation_import_test_helpers +): """this test should check running state only to validate running, not completed""" name = str(uuid.uuid4()) - label_import = LabelImport.create(client=client, - id=module_project.uid, - name=name, - labels=object_predictions) + label_import = LabelImport.create( + client=client, + id=module_project.uid, + name=name, + labels=object_predictions, + ) assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) + label_import.input_file_url, object_predictions + ) -def test_create_from_objects(client, module_project, object_predictions, - annotation_import_test_helpers): +def test_create_from_objects( + client, module_project, object_predictions, annotation_import_test_helpers +): """this test should check running state only to validate running, not completed""" name = str(uuid.uuid4()) @@ -61,16 +65,23 @@ def test_create_from_objects(client, module_project, object_predictions, client=client, project_id=module_project.uid, name=name, - labels=object_predictions) + labels=object_predictions, + ) assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) + label_import.input_file_url, object_predictions + ) -def test_create_with_path_arg(client, tmp_path, module_project, object_predictions, - annotation_import_test_helpers): +def test_create_with_path_arg( + client, + tmp_path, + module_project, + object_predictions, + annotation_import_test_helpers, +): project = module_project name = str(uuid.uuid4()) file_name = f"{name}.ndjson" @@ -78,19 +89,24 @@ def test_create_with_path_arg(client, tmp_path, module_project, object_predictio with file_path.open("w") as f: parser.dump(object_predictions, f) - label_import = LabelImport.create(client=client, - id=project.uid, - name=name, - path=str(file_path)) + label_import = LabelImport.create( + client=client, id=project.uid, name=name, path=str(file_path) + ) assert label_import.parent_id == project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) + label_import.input_file_url, object_predictions + ) -def test_create_from_local_file(client, tmp_path, module_project, object_predictions, - annotation_import_test_helpers): +def test_create_from_local_file( + client, + tmp_path, + module_project, + object_predictions, + annotation_import_test_helpers, +): project = module_project name = str(uuid.uuid4()) file_name = f"{name}.ndjson" @@ -98,26 +114,23 @@ def test_create_from_local_file(client, tmp_path, module_project, object_predict with file_path.open("w") as f: parser.dump(object_predictions, f) - label_import = LabelImport.create_from_file(client=client, - project_id=project.uid, - name=name, - path=str(file_path)) + label_import = LabelImport.create_from_file( + client=client, project_id=project.uid, name=name, path=str(file_path) + ) assert 
label_import.parent_id == project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) + label_import.input_file_url, object_predictions + ) -def test_get(client, module_project, - annotation_import_test_helpers): +def test_get(client, module_project, annotation_import_test_helpers): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" label_import = LabelImport.create_from_url( - client=client, - project_id=module_project.uid, - name=name, - url=url) + client=client, project_id=module_project.uid, name=name, url=url + ) assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name, url) @@ -130,18 +143,19 @@ def test_wait_till_done(client, module_project, predictions): client=client, project_id=module_project.uid, name=name, - labels=predictions) + labels=predictions, + ) assert len(label_import.inputs) == len(predictions) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED assert len(label_import.inputs) == len(predictions) - input_uuids = [input_annot['uuid'] for input_annot in label_import.inputs] - inference_uuids = [pred['uuid'] for pred in predictions] + input_uuids = [input_annot["uuid"] for input_annot in label_import.inputs] + inference_uuids = [pred["uuid"] for pred in predictions] assert set(input_uuids) == set(inference_uuids) assert len(label_import.statuses) == len(predictions) status_uuids = [ - input_annot['uuid'] for input_annot in label_import.statuses + input_annot["uuid"] for input_annot in label_import.statuses ] assert set(input_uuids) == set(status_uuids) diff --git a/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py b/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py index c50c82315..3ffd6bfc1 100644 --- a/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py @@ -2,6 +2,7 @@ from labelbox import parser from labelbox.schema.annotation_import import MALPredictionImport + """ - Here we only want to check that the uploads are calling the validation - Then with unit tests we can check the types of errors raised @@ -9,37 +10,45 @@ """ -def test_create_with_url_arg(client, module_project, - annotation_import_test_helpers): +def test_create_with_url_arg( + client, module_project, annotation_import_test_helpers +): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" label_import = MALPredictionImport.create( - client=client, - id=module_project.uid, - name=name, - url=url) + client=client, id=module_project.uid, name=name, url=url + ) assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name, url) -def test_create_with_labels_arg(client, module_project, object_predictions, - annotation_import_test_helpers): +def test_create_with_labels_arg( + client, module_project, object_predictions, annotation_import_test_helpers +): """this test should check running state only to validate running, not completed""" name = str(uuid.uuid4()) - label_import = MALPredictionImport.create(client=client, - id=module_project.uid, - name=name, - labels=object_predictions) + label_import = MALPredictionImport.create( + client=client, + 
id=module_project.uid, + name=name, + labels=object_predictions, + ) assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) + label_import.input_file_url, object_predictions + ) -def test_create_with_path_arg(client, tmp_path, configured_project, object_predictions, - annotation_import_test_helpers): +def test_create_with_path_arg( + client, + tmp_path, + configured_project, + object_predictions, + annotation_import_test_helpers, +): project = configured_project name = str(uuid.uuid4()) file_name = f"{name}.ndjson" @@ -47,12 +56,12 @@ def test_create_with_path_arg(client, tmp_path, configured_project, object_predi with file_path.open("w") as f: parser.dump(object_predictions, f) - label_import = MALPredictionImport.create(client=client, - id=project.uid, - name=name, - path=str(file_path)) + label_import = MALPredictionImport.create( + client=client, id=project.uid, name=name, path=str(file_path) + ) assert label_import.parent_id == project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) + label_import.input_file_url, object_predictions + ) diff --git a/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py b/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py index f2765fd3f..fccca2a3f 100644 --- a/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py @@ -3,125 +3,160 @@ import pytest from labelbox import ModelRun -from labelbox.schema.annotation_import import AnnotationImportState, MEAPredictionImport +from labelbox.schema.annotation_import import ( + AnnotationImportState, + MEAPredictionImport, +) from labelbox.data.serialization import NDJsonConverter from labelbox.schema.export_params import ModelRunExportParams + """ - Here we only want to check that the uploads are calling the validation - Then with unit tests we can check the types of errors raised """ + @pytest.mark.order(1) -def test_create_from_objects(model_run_with_data_rows, - object_predictions_for_annotation_import, - annotation_import_test_helpers): +def test_create_from_objects( + model_run_with_data_rows, + object_predictions_for_annotation_import, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) object_predictions = object_predictions_for_annotation_import - use_data_row_ids = [p['dataRow']['id'] for p in object_predictions] + use_data_row_ids = [p["dataRow"]["id"] for p in object_predictions] model_run_with_data_rows.upsert_data_rows(use_data_row_ids) annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=object_predictions) + name=name, predictions=object_predictions + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid annotation_import_test_helpers.check_running_state(annotation_import, name) annotation_import_test_helpers.assert_file_content( - annotation_import.input_file_url, object_predictions) + annotation_import.input_file_url, object_predictions + ) annotation_import.wait_until_done() assert annotation_import.state == AnnotationImportState.FINISHED annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) + annotation_import.status_file_url + ) -def 
test_create_from_objects_global_key(client, model_run_with_data_rows, - polygon_inference, - annotation_import_test_helpers): +def test_create_from_objects_global_key( + client, + model_run_with_data_rows, + polygon_inference, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) - dr = client.get_data_row(polygon_inference[0]['dataRow']['id']) - polygon_inference[0]['dataRow']['globalKey'] = dr.global_key - del polygon_inference[0]['dataRow']['id'] + dr = client.get_data_row(polygon_inference[0]["dataRow"]["id"]) + polygon_inference[0]["dataRow"]["globalKey"] = dr.global_key + del polygon_inference[0]["dataRow"]["id"] object_predictions = [polygon_inference[0]] annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=object_predictions) + name=name, predictions=object_predictions + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid annotation_import_test_helpers.check_running_state(annotation_import, name) annotation_import_test_helpers.assert_file_content( - annotation_import.input_file_url, object_predictions) + annotation_import.input_file_url, object_predictions + ) annotation_import.wait_until_done() assert annotation_import.state == AnnotationImportState.FINISHED annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) + annotation_import.status_file_url + ) -def test_create_from_objects_with_confidence(predictions_with_confidence, - model_run_with_data_rows, - annotation_import_test_helpers): +def test_create_from_objects_with_confidence( + predictions_with_confidence, + model_run_with_data_rows, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) - object_prediction_data_rows = set([ - object_prediction["dataRow"]["id"] - for object_prediction in predictions_with_confidence - ]) + object_prediction_data_rows = set( + [ + object_prediction["dataRow"]["id"] + for object_prediction in predictions_with_confidence + ] + ) # MUST have all data rows in the model run model_run_with_data_rows.upsert_data_rows( - data_row_ids=list(object_prediction_data_rows)) + data_row_ids=list(object_prediction_data_rows) + ) annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=predictions_with_confidence) + name=name, predictions=predictions_with_confidence + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid annotation_import_test_helpers.check_running_state(annotation_import, name) annotation_import_test_helpers.assert_file_content( - annotation_import.input_file_url, predictions_with_confidence) + annotation_import.input_file_url, predictions_with_confidence + ) annotation_import.wait_until_done() assert annotation_import.state == AnnotationImportState.FINISHED annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) + annotation_import.status_file_url + ) def test_create_from_objects_all_project_labels( - model_run_with_all_project_labels, - object_predictions_for_annotation_import, - annotation_import_test_helpers): + model_run_with_all_project_labels, + object_predictions_for_annotation_import, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) object_predictions = object_predictions_for_annotation_import - use_data_row_ids = [p['dataRow']['id'] for p in object_predictions] + use_data_row_ids = [p["dataRow"]["id"] for p in object_predictions] model_run_with_all_project_labels.upsert_data_rows(use_data_row_ids) annotation_import = model_run_with_all_project_labels.add_predictions( - 
name=name, predictions=object_predictions) + name=name, predictions=object_predictions + ) - assert annotation_import.model_run_id == model_run_with_all_project_labels.uid + assert ( + annotation_import.model_run_id == model_run_with_all_project_labels.uid + ) annotation_import_test_helpers.check_running_state(annotation_import, name) annotation_import_test_helpers.assert_file_content( - annotation_import.input_file_url, object_predictions) + annotation_import.input_file_url, object_predictions + ) annotation_import.wait_until_done() assert annotation_import.state == AnnotationImportState.FINISHED annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) + annotation_import.status_file_url + ) -def test_model_run_project_labels(model_run_with_all_project_labels: ModelRun, - model_run_predictions): - +def test_model_run_project_labels( + model_run_with_all_project_labels: ModelRun, model_run_predictions +): model_run = model_run_with_all_project_labels export_task = model_run.export() export_task.wait_till_done() stream = export_task.get_buffered_stream() - + # exports to list of tuples (data_row_id, label) needed to adapt test to export v2 instead of export v1 since data rows ids are not at label level in export v2. - model_run_exported_labels = [( - data_row.json["data_row"]["id"], - data_row.json["experiments"][model_run.model_id]["runs"][model_run.uid]["labels"][0]) - for data_row in stream] - + model_run_exported_labels = [ + ( + data_row.json["data_row"]["id"], + data_row.json["experiments"][model_run.model_id]["runs"][ + model_run.uid + ]["labels"][0], + ) + for data_row in stream + ] + labels_indexed_by_name = {} # making sure the labels are in this model run are all labels uploaded to the project @@ -130,51 +165,69 @@ def test_model_run_project_labels(model_run_with_all_project_labels: ModelRun, for data_row_id, label in model_run_exported_labels: for object in label["annotations"]["objects"]: name = object["name"] - labels_indexed_by_name[f"{name}-{data_row_id}"] = {"label": label, "data_row_id": data_row_id} - - assert (len( - labels_indexed_by_name.keys())) == len([prediction["dataRow"]["id"] for prediction in model_run_predictions]) - - expected_data_row_ids = set([prediction["dataRow"]["id"] for prediction in model_run_predictions]) - expected_objects = set([prediction["name"] for prediction in model_run_predictions]) + labels_indexed_by_name[f"{name}-{data_row_id}"] = { + "label": label, + "data_row_id": data_row_id, + } + + assert (len(labels_indexed_by_name.keys())) == len( + [prediction["dataRow"]["id"] for prediction in model_run_predictions] + ) + + expected_data_row_ids = set( + [prediction["dataRow"]["id"] for prediction in model_run_predictions] + ) + expected_objects = set( + [prediction["name"] for prediction in model_run_predictions] + ) for data_row_id, actual_label in model_run_exported_labels: assert data_row_id in expected_data_row_ids - assert len(expected_objects) == len(actual_label["annotations"]["objects"]) + assert len(expected_objects) == len( + actual_label["annotations"]["objects"] + ) - -def test_create_from_label_objects(model_run_with_data_rows, - object_predictions_for_annotation_import, - annotation_import_test_helpers): +def test_create_from_label_objects( + model_run_with_data_rows, + object_predictions_for_annotation_import, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) use_data_row_ids = [ - p['dataRow']['id'] for p in object_predictions_for_annotation_import + p["dataRow"]["id"] for p in 
object_predictions_for_annotation_import ] model_run_with_data_rows.upsert_data_rows(use_data_row_ids) predictions = list( - NDJsonConverter.deserialize(object_predictions_for_annotation_import)) + NDJsonConverter.deserialize(object_predictions_for_annotation_import) + ) annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=predictions) + name=name, predictions=predictions + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid annotation_import_test_helpers.check_running_state(annotation_import, name) normalized_predictions = NDJsonConverter.serialize(predictions) annotation_import_test_helpers.assert_file_content( - annotation_import.input_file_url, normalized_predictions) + annotation_import.input_file_url, normalized_predictions + ) annotation_import.wait_until_done() assert annotation_import.state == AnnotationImportState.FINISHED annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) + annotation_import.status_file_url + ) -def test_create_from_local_file(tmp_path, model_run_with_data_rows, - object_predictions_for_annotation_import, - annotation_import_test_helpers): +def test_create_from_local_file( + tmp_path, + model_run_with_data_rows, + object_predictions_for_annotation_import, + annotation_import_test_helpers, +): use_data_row_ids = [ - p['dataRow']['id'] for p in object_predictions_for_annotation_import + p["dataRow"]["id"] for p in object_predictions_for_annotation_import ] model_run_with_data_rows.upsert_data_rows(use_data_row_ids) @@ -185,30 +238,36 @@ def test_create_from_local_file(tmp_path, model_run_with_data_rows, parser.dump(object_predictions_for_annotation_import, f) annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=str(file_path)) + name=name, predictions=str(file_path) + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid annotation_import_test_helpers.check_running_state(annotation_import, name) annotation_import_test_helpers.assert_file_content( annotation_import.input_file_url, - object_predictions_for_annotation_import) + object_predictions_for_annotation_import, + ) annotation_import.wait_until_done() assert annotation_import.state == AnnotationImportState.FINISHED annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) + annotation_import.status_file_url + ) def test_predictions_with_custom_metrics( - model_run, object_predictions_for_annotation_import, - annotation_import_test_helpers): + model_run, + object_predictions_for_annotation_import, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) object_predictions = object_predictions_for_annotation_import - use_data_row_ids = [p['dataRow']['id'] for p in object_predictions] + use_data_row_ids = [p["dataRow"]["id"] for p in object_predictions] model_run.upsert_data_rows(use_data_row_ids) annotation_import = model_run.add_predictions( - name=name, predictions=object_predictions) + name=name, predictions=object_predictions + ) assert annotation_import.model_run_id == model_run.uid annotation_import.wait_until_done() @@ -219,7 +278,8 @@ def test_predictions_with_custom_metrics( assert annotation_import.state == AnnotationImportState.FINISHED annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) + annotation_import.status_file_url + ) def test_get(client, model_run_with_data_rows, annotation_import_test_helpers): @@ -228,11 +288,13 @@ def test_get(client, 
model_run_with_data_rows, annotation_import_test_helpers): model_run_with_data_rows.add_predictions(name=name, predictions=url) annotation_import = MEAPredictionImport.from_name( - client, model_run_id=model_run_with_data_rows.uid, name=name) + client, model_run_id=model_run_with_data_rows.uid, name=name + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid - annotation_import_test_helpers.check_running_state(annotation_import, name, - url) + annotation_import_test_helpers.check_running_state( + annotation_import, name, url + ) annotation_import.wait_until_done() @@ -240,7 +302,8 @@ def test_get(client, model_run_with_data_rows, annotation_import_test_helpers): def test_wait_till_done(model_run_predictions, model_run_with_data_rows): name = str(uuid.uuid4()) annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=model_run_predictions) + name=name, predictions=model_run_predictions + ) assert len(annotation_import.inputs) == len(model_run_predictions) annotation_import.wait_until_done() @@ -249,14 +312,14 @@ def test_wait_till_done(model_run_predictions, model_run_with_data_rows): assert len(annotation_import.errors) == 0 assert len(annotation_import.inputs) == len(model_run_predictions) input_uuids = [ - input_annot['uuid'] for input_annot in annotation_import.inputs + input_annot["uuid"] for input_annot in annotation_import.inputs ] - inference_uuids = [pred['uuid'] for pred in model_run_predictions] + inference_uuids = [pred["uuid"] for pred in model_run_predictions] assert set(input_uuids) == set(inference_uuids) assert len(annotation_import.statuses) == len(model_run_predictions) for status in annotation_import.statuses: - assert status['status'] == 'SUCCESS' + assert status["status"] == "SUCCESS" status_uuids = [ - input_annot['uuid'] for input_annot in annotation_import.statuses + input_annot["uuid"] for input_annot in annotation_import.statuses ] assert set(input_uuids) == set(status_uuids) diff --git a/libs/labelbox/tests/data/annotation_import/test_model_run.py b/libs/labelbox/tests/data/annotation_import/test_model_run.py index bf30ed169..9eca28429 100644 --- a/libs/labelbox/tests/data/annotation_import/test_model_run.py +++ b/libs/labelbox/tests/data/annotation_import/test_model_run.py @@ -6,6 +6,7 @@ from labelbox import DataSplit, ModelRun + @pytest.mark.order(1) def test_model_run(client, configured_project_with_label, data_row, rand_gen): project, _, _, label = configured_project_with_label @@ -87,19 +88,19 @@ def test_model_run_data_rows_delete(model_run_with_data_rows): assert len(before) == len(after) + 1 -def test_model_run_upsert_data_rows(dataset, model_run, - configured_project): +def test_model_run_upsert_data_rows(dataset, model_run, configured_project): n_model_run_data_rows = len(list(model_run.model_run_data_rows())) assert n_model_run_data_rows == 0 data_row = dataset.create_data_row(row_data="test row data") configured_project._wait_until_data_rows_are_processed( - data_row_ids=[data_row.uid]) + data_row_ids=[data_row.uid] + ) model_run.upsert_data_rows([data_row.uid]) n_model_run_data_rows = len(list(model_run.model_run_data_rows())) assert n_model_run_data_rows == 1 -@pytest.mark.parametrize('data_rows', [2], indirect=True) +@pytest.mark.parametrize("data_rows", [2], indirect=True) def test_model_run_upsert_data_rows_using_global_keys(model_run, data_rows): global_keys = [dr.global_key for dr in data_rows] assert model_run.upsert_data_rows(global_keys=global_keys) @@ -109,68 +110,77 @@ def 
test_model_run_upsert_data_rows_using_global_keys(model_run, data_rows): def test_model_run_upsert_data_rows_with_existing_labels( - model_run_with_data_rows): + model_run_with_data_rows, +): model_run_data_rows = list(model_run_with_data_rows.model_run_data_rows()) n_data_rows = len(model_run_data_rows) - model_run_with_data_rows.upsert_data_rows([ - model_run_data_row.data_row().uid - for model_run_data_row in model_run_data_rows - ]) + model_run_with_data_rows.upsert_data_rows( + [ + model_run_data_row.data_row().uid + for model_run_data_row in model_run_data_rows + ] + ) assert n_data_rows == len( - list(model_run_with_data_rows.model_run_data_rows())) + list(model_run_with_data_rows.model_run_data_rows()) + ) -@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", - reason="does not work for onprem") +@pytest.mark.skipif( + condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem", + reason="does not work for onprem", +) def test_model_run_status(model_run_with_data_rows): - def get_model_run_status(): return model_run_with_data_rows.client.execute( """query trainingPipelinePyApi($modelRunId: ID!) { trainingPipeline(where: {id : $modelRunId}) {status, errorMessage, metadata}} - """, {'modelRunId': model_run_with_data_rows.uid}, - experimental=True)['trainingPipeline'] + """, + {"modelRunId": model_run_with_data_rows.uid}, + experimental=True, + )["trainingPipeline"] model_run_status = get_model_run_status() - assert model_run_status['status'] is None - assert model_run_status['metadata'] is None - assert model_run_status['errorMessage'] is None + assert model_run_status["status"] is None + assert model_run_status["metadata"] is None + assert model_run_status["errorMessage"] is None status = "COMPLETE" - metadata = {'key1': 'value1'} + metadata = {"key1": "value1"} errorMessage = "an error" model_run_with_data_rows.update_status(status, metadata, errorMessage) model_run_status = get_model_run_status() - assert model_run_status['status'] == status - assert model_run_status['metadata'] == metadata - assert model_run_status['errorMessage'] == errorMessage + assert model_run_status["status"] == status + assert model_run_status["metadata"] == metadata + assert model_run_status["errorMessage"] == errorMessage - extra_metadata = {'key2': 'value2'} + extra_metadata = {"key2": "value2"} model_run_with_data_rows.update_status(status, extra_metadata) model_run_status = get_model_run_status() - assert model_run_status['status'] == status - assert model_run_status['metadata'] == {**metadata, **extra_metadata} - assert model_run_status['errorMessage'] == errorMessage + assert model_run_status["status"] == status + assert model_run_status["metadata"] == {**metadata, **extra_metadata} + assert model_run_status["errorMessage"] == errorMessage status = ModelRun.Status.FAILED model_run_with_data_rows.update_status(status, metadata, errorMessage) model_run_status = get_model_run_status() - assert model_run_status['status'] == status.value + assert model_run_status["status"] == status.value with pytest.raises(ValueError): - model_run_with_data_rows.update_status("INVALID", metadata, - errorMessage) + model_run_with_data_rows.update_status( + "INVALID", metadata, errorMessage + ) -def test_model_run_split_assignment_by_data_row_ids(model_run, dataset, - image_url): +def test_model_run_split_assignment_by_data_row_ids( + model_run, dataset, image_url +): n_data_rows = 2 - data_rows = dataset.create_data_rows([{ - "row_data": image_url - } for _ in range(n_data_rows)]) + data_rows = 
dataset.create_data_rows( + [{"row_data": image_url} for _ in range(n_data_rows)] + ) data_rows.wait_till_done() - data_row_ids = [data_row['id'] for data_row in data_rows.result] + data_row_ids = [data_row["id"] for data_row in data_rows.result] model_run.upsert_data_rows(data_row_ids) with pytest.raises(ValueError): @@ -185,15 +195,16 @@ def test_model_run_split_assignment_by_data_row_ids(model_run, dataset, assert counts[split] == n_data_rows -@pytest.mark.parametrize('data_rows', [2], indirect=True) +@pytest.mark.parametrize("data_rows", [2], indirect=True) def test_model_run_split_assignment_by_global_keys(model_run, data_rows): global_keys = [data_row.global_key for data_row in data_rows] model_run.upsert_data_rows(global_keys=global_keys) for split in ["TRAINING", "TEST", "VALIDATION", "UNASSIGNED", *DataSplit]: - model_run.assign_data_rows_to_split(split=split, - global_keys=global_keys) + model_run.assign_data_rows_to_split( + split=split, global_keys=global_keys + ) splits = [ data_row.data_split.value for data_row in model_run.model_run_data_rows() diff --git a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py index ac197a321..a0df559fc 100644 --- a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py +++ b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py @@ -6,15 +6,25 @@ from pytest_cases import parametrize, fixture_ref from labelbox.exceptions import MALValidationError -from labelbox.schema.bulk_import_request import (NDChecklist, NDClassification, - NDMask, NDPolygon, NDPolyline, - NDRadio, NDRectangle, NDText, - NDTextEntity, NDTool, - _validate_ndjson) +from labelbox.schema.bulk_import_request import ( + NDChecklist, + NDClassification, + NDMask, + NDPolygon, + NDPolyline, + NDRadio, + NDRectangle, + NDText, + NDTextEntity, + NDTool, + _validate_ndjson, +) + """ - These NDlabels are apart of bulkImportReqeust and should be removed once bulk import request is removed """ + def test_classification_construction(checklist_inference, text_inference): checklist = NDClassification.build(checklist_inference[0]) assert isinstance(checklist, NDChecklist) @@ -22,97 +32,93 @@ def test_classification_construction(checklist_inference, text_inference): assert isinstance(text, NDText) -@parametrize("inference, expected_type", - [(fixture_ref('polygon_inference'), NDPolygon), - (fixture_ref('rectangle_inference'), NDRectangle), - (fixture_ref('line_inference'), NDPolyline), - (fixture_ref('entity_inference'), NDTextEntity), - (fixture_ref('segmentation_inference'), NDMask), - (fixture_ref('segmentation_inference_rle'), NDMask), - (fixture_ref('segmentation_inference_png'), NDMask)]) +@parametrize( + "inference, expected_type", + [ + (fixture_ref("polygon_inference"), NDPolygon), + (fixture_ref("rectangle_inference"), NDRectangle), + (fixture_ref("line_inference"), NDPolyline), + (fixture_ref("entity_inference"), NDTextEntity), + (fixture_ref("segmentation_inference"), NDMask), + (fixture_ref("segmentation_inference_rle"), NDMask), + (fixture_ref("segmentation_inference_png"), NDMask), + ], +) def test_tool_construction(inference, expected_type): assert isinstance(NDTool.build(inference[0]), expected_type) def no_tool(text_inference, module_project): pred = text_inference[0].copy() - #Missing key - del pred['answer'] + # Missing key + del pred["answer"] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) -@pytest.mark.parametrize( - 
"configured_project", - [MediaType.Text], - indirect=True -) + +@pytest.mark.parametrize("configured_project", [MediaType.Text], indirect=True) def test_invalid_text(text_inference, configured_project): - #and if it is not a string + # and if it is not a string pred = text_inference[0].copy() - #Extra and wrong key - del pred['answer'] - pred['answers'] = [] + # Extra and wrong key + del pred["answer"] + pred["answers"] = [] with pytest.raises(MALValidationError): _validate_ndjson([pred], configured_project) - del pred['answers'] + del pred["answers"] - #Invalid type - pred['answer'] = [] + # Invalid type + pred["answer"] = [] with pytest.raises(MALValidationError): _validate_ndjson([pred], configured_project) - #Invalid type - pred['answer'] = None + # Invalid type + pred["answer"] = None with pytest.raises(MALValidationError): _validate_ndjson([pred], configured_project) -def test_invalid_checklist_item(checklist_inference, - module_project): - #Only two points +def test_invalid_checklist_item(checklist_inference, module_project): + # Only two points pred = checklist_inference[0].copy() - pred['answers'] = [pred['answers'][0], pred['answers'][0]] - #Duplicate schema ids + pred["answers"] = [pred["answers"][0], pred["answers"][0]] + # Duplicate schema ids with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) - pred['answers'] = [{"name": "asdfg"}] + pred["answers"] = [{"name": "asdfg"}] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) - pred['answers'] = [{"schemaId": "1232132132"}] + pred["answers"] = [{"schemaId": "1232132132"}] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) - pred['answers'] = [{}] + pred["answers"] = [{}] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) - pred['answers'] = [] + pred["answers"] = [] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) - del pred['answers'] + del pred["answers"] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) def test_invalid_polygon(polygon_inference, module_project): - #Only two points + # Only two points pred = polygon_inference[0].copy() - pred['polygon'] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}] + pred["polygon"] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) -@pytest.mark.parametrize( - "configured_project", - [MediaType.Text], - indirect=True -) +@pytest.mark.parametrize("configured_project", [MediaType.Text], indirect=True) def test_incorrect_entity(entity_inference, configured_project): entity = entity_inference[0].copy() - #Location cannot be a list + # Location cannot be a list entity["location"] = [0, 10] with pytest.raises(MALValidationError): _validate_ndjson([entity], configured_project) @@ -126,53 +132,50 @@ def test_incorrect_entity(entity_inference, configured_project): _validate_ndjson([entity], configured_project) -@pytest.mark.skip("Test wont work/fails randomly since projects have to have a media type and could be missing features from prediction list") +@pytest.mark.skip( + "Test wont work/fails randomly since projects have to have a media type and could be missing features from prediction list" +) def test_all_validate_json(module_project, predictions): - #Predictions contains one of each type of prediction. - #These should be properly formatted and pass. + # Predictions contains one of each type of prediction. 
+ # These should be properly formatted and pass. _validate_ndjson(predictions[0], module_project) def test_incorrect_line(line_inference, module_project): line = line_inference[0].copy() - line["line"] = [line["line"][0]] #Just one point + line["line"] = [line["line"][0]] # Just one point with pytest.raises(MALValidationError): _validate_ndjson([line], module_project) -def test_incorrect_rectangle(rectangle_inference, - module_project): - del rectangle_inference[0]['bbox']['top'] +def test_incorrect_rectangle(rectangle_inference, module_project): + del rectangle_inference[0]["bbox"]["top"] with pytest.raises(MALValidationError): - _validate_ndjson([rectangle_inference], - module_project) + _validate_ndjson([rectangle_inference], module_project) def test_duplicate_tools(rectangle_inference, module_project): pred = rectangle_inference[0].copy() - pred['polygon'] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}] + pred["polygon"] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) -def test_invalid_feature_schema(module_project, - rectangle_inference): +def test_invalid_feature_schema(module_project, rectangle_inference): pred = rectangle_inference[0].copy() - pred['schemaId'] = "blahblah" + pred["schemaId"] = "blahblah" with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) -def test_name_only_feature_schema(module_project, - rectangle_inference): +def test_name_only_feature_schema(module_project, rectangle_inference): pred = rectangle_inference[0].copy() _validate_ndjson([pred], module_project) -def test_schema_id_only_feature_schema(module_project, - rectangle_inference): +def test_schema_id_only_feature_schema(module_project, rectangle_inference): pred = rectangle_inference[0].copy() - del pred['name'] + del pred["name"] ontology = module_project.ontology().normalized["tools"] for tool in ontology: if tool["name"] == "bbox": @@ -181,10 +184,9 @@ def test_schema_id_only_feature_schema(module_project, _validate_ndjson([pred], module_project) -def test_missing_feature_schema(module_project, - rectangle_inference): +def test_missing_feature_schema(module_project, rectangle_inference): pred = rectangle_inference[0].copy() - del pred['name'] + del pred["name"] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) @@ -197,31 +199,32 @@ def test_validate_ndjson(tmp_path, configured_project): with pytest.raises(ValueError): configured_project.upload_annotations( - name="name", annotations=str(file_path), validate=True) + name="name", annotations=str(file_path), validate=True + ) -def test_validate_ndjson_uuid(tmp_path, configured_project, - predictions): +def test_validate_ndjson_uuid(tmp_path, configured_project, predictions): file_name = f"repeat_uuid.ndjson" file_path = tmp_path / file_name repeat_uuid = predictions.copy() - repeat_uuid[0]['uuid'] = 'test_uuid' - repeat_uuid[1]['uuid'] = 'test_uuid' + repeat_uuid[0]["uuid"] = "test_uuid" + repeat_uuid[1]["uuid"] = "test_uuid" with file_path.open("w") as f: parser.dump(repeat_uuid, f) with pytest.raises(MALValidationError): configured_project.upload_annotations( - name="name", validate=True, annotations=str(file_path)) + name="name", validate=True, annotations=str(file_path) + ) with pytest.raises(MALValidationError): configured_project.upload_annotations( - name="name", validate=True, annotations=repeat_uuid) + name="name", validate=True, annotations=repeat_uuid + ) @pytest.mark.parametrize("configured_project", 
[MediaType.Video], indirect=True) -def test_video_upload(video_checklist_inference, - configured_project): +def test_video_upload(video_checklist_inference, configured_project): pred = video_checklist_inference[0].copy() _validate_ndjson([pred], configured_project) diff --git a/libs/labelbox/tests/data/annotation_import/test_send_to_annotate_mea.py b/libs/labelbox/tests/data/annotation_import/test_send_to_annotate_mea.py index 1f8b84742..4bcd4dcef 100644 --- a/libs/labelbox/tests/data/annotation_import/test_send_to_annotate_mea.py +++ b/libs/labelbox/tests/data/annotation_import/test_send_to_annotate_mea.py @@ -1,14 +1,22 @@ import pytest from labelbox import UniqueIds, OntologyBuilder -from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy +from labelbox.schema.conflict_resolution_strategy import ( + ConflictResolutionStrategy, +) -def test_send_to_annotate_from_model(client, configured_project, - model_run_predictions, - model_run_with_data_rows, project): +def test_send_to_annotate_from_model( + client, + configured_project, + model_run_predictions, + model_run_with_data_rows, + project, +): model_run = model_run_with_data_rows - data_row_ids = list(set([p['dataRow']['id'] for p in model_run_predictions])) + data_row_ids = list( + set([p["dataRow"]["id"] for p in model_run_predictions]) + ) assert len(data_row_ids) > 0 destination_project = project @@ -18,22 +26,27 @@ def test_send_to_annotate_from_model(client, configured_project, queues = destination_project.task_queues() initial_review_task = next( - q for q in queues if q.name == "Initial review task") + q for q in queues if q.name == "Initial review task" + ) # build an ontology mapping using the top level tools and classifications source_ontology_builder = OntologyBuilder.from_project(configured_project) feature_schema_ids = list( - tool.feature_schema_id for tool in source_ontology_builder.tools) + tool.feature_schema_id for tool in source_ontology_builder.tools + ) # create a dictionary of feature schema id to itself ontology_mapping = dict(zip(feature_schema_ids, feature_schema_ids)) classification_feature_schema_ids = list( classification.feature_schema_id - for classification in source_ontology_builder.classifications) + for classification in source_ontology_builder.classifications + ) # create a dictionary of feature schema id to itself classification_ontology_mapping = dict( - zip(classification_feature_schema_ids, - classification_feature_schema_ids)) + zip( + classification_feature_schema_ids, classification_feature_schema_ids + ) + ) # combine the two ontology mappings ontology_mapping.update(classification_ontology_mapping) @@ -44,11 +57,10 @@ def test_send_to_annotate_from_model(client, configured_project, data_rows=UniqueIds(data_row_ids), task_queue_id=initial_review_task.uid, params={ - "predictions_ontology_mapping": - ontology_mapping, - "override_existing_annotations_rule": - ConflictResolutionStrategy.OverrideWithPredictions - }) + "predictions_ontology_mapping": ontology_mapping, + "override_existing_annotations_rule": ConflictResolutionStrategy.OverrideWithPredictions, + }, + ) task.wait_till_done() @@ -66,5 +78,5 @@ def test_send_to_annotate_from_model(client, configured_project, assert all([dr in data_row_ids for dr in destination_data_rows]) # Since data rows were added to a review queue, predictions should be imported into the project as labels - destination_project_labels = (list(destination_project.labels())) + destination_project_labels = list(destination_project.labels()) 
assert len(destination_project_labels) == len(data_row_ids) diff --git a/libs/labelbox/tests/data/annotation_import/test_upsert_prediction_import.py b/libs/labelbox/tests/data/annotation_import/test_upsert_prediction_import.py index 59c894c65..a60e0aa59 100644 --- a/libs/labelbox/tests/data/annotation_import/test_upsert_prediction_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_upsert_prediction_import.py @@ -1,6 +1,7 @@ import uuid from labelbox import parser import pytest + """ - Here we only want to check that the uploads are calling the validation - Then with unit tests we can check the types of errors raised @@ -9,10 +10,14 @@ @pytest.mark.skip() -def test_create_from_url(client, tmp_path, object_predictions, - model_run_with_data_rows, - configured_project, - annotation_import_test_helpers): +def test_create_from_url( + client, + tmp_path, + object_predictions, + model_run_with_data_rows, + configured_project, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) file_name = f"{name}.json" file_path = tmp_path / file_name @@ -22,8 +27,9 @@ def test_create_from_url(client, tmp_path, object_predictions, for mrdr in model_run_with_data_rows.model_run_data_rows() ] predictions = [ - p for p in object_predictions - if p['dataRow']['id'] in model_run_data_rows + p + for p in object_predictions + if p["dataRow"]["id"] in model_run_data_rows ] with file_path.open("w") as f: parser.dump(predictions, f) @@ -31,16 +37,21 @@ def test_create_from_url(client, tmp_path, object_predictions, # Needs to have data row ids with open(file_path, "r") as f: - url = client.upload_data(content=f.read(), - filename=file_name, - sign=True, - content_type="application/json") - - annotation_import, batch, mal_prediction_import = model_run_with_data_rows.upsert_predictions_and_send_to_project( - name=name, - predictions=url, - project_id=configured_project.uid, - priority=5) + url = client.upload_data( + content=f.read(), + filename=file_name, + sign=True, + content_type="application/json", + ) + + annotation_import, batch, mal_prediction_import = ( + model_run_with_data_rows.upsert_predictions_and_send_to_project( + name=name, + predictions=url, + project_id=configured_project.uid, + priority=5, + ) + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid annotation_import.wait_until_done() @@ -58,24 +69,30 @@ def test_create_from_url(client, tmp_path, object_predictions, @pytest.mark.skip() -def test_create_from_objects(model_run_with_data_rows, - configured_project, - object_predictions, - annotation_import_test_helpers): +def test_create_from_objects( + model_run_with_data_rows, + configured_project, + object_predictions, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) model_run_data_rows = [ mrdr.data_row().uid for mrdr in model_run_with_data_rows.model_run_data_rows() ] predictions = [ - p for p in object_predictions - if p['dataRow']['id'] in model_run_data_rows + p + for p in object_predictions + if p["dataRow"]["id"] in model_run_data_rows ] - annotation_import, batch, mal_prediction_import = model_run_with_data_rows.upsert_predictions_and_send_to_project( - name=name, - predictions=predictions, - project_id=configured_project.uid, - priority=5) + annotation_import, batch, mal_prediction_import = ( + model_run_with_data_rows.upsert_predictions_and_send_to_project( + name=name, + predictions=predictions, + project_id=configured_project.uid, + priority=5, + ) + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid 
annotation_import.wait_until_done() @@ -93,11 +110,13 @@ def test_create_from_objects(model_run_with_data_rows, @pytest.mark.skip() -def test_create_from_local_file(tmp_path, model_run_with_data_rows, - configured_project_with_one_data_row, - object_predictions, - annotation_import_test_helpers): - +def test_create_from_local_file( + tmp_path, + model_run_with_data_rows, + configured_project_with_one_data_row, + object_predictions, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) file_name = f"{name}.ndjson" file_path = tmp_path / file_name @@ -107,18 +126,22 @@ def test_create_from_local_file(tmp_path, model_run_with_data_rows, for mrdr in model_run_with_data_rows.model_run_data_rows() ] predictions = [ - p for p in object_predictions - if p['dataRow']['id'] in model_run_data_rows + p + for p in object_predictions + if p["dataRow"]["id"] in model_run_data_rows ] with file_path.open("w") as f: parser.dump(predictions, f) - annotation_import, batch, mal_prediction_import = model_run_with_data_rows.upsert_predictions_and_send_to_project( - name=name, - predictions=str(file_path), - project_id=configured_project_with_one_data_row.uid, - priority=5) + annotation_import, batch, mal_prediction_import = ( + model_run_with_data_rows.upsert_predictions_and_send_to_project( + name=name, + predictions=str(file_path), + project_id=configured_project_with_one_data_row.uid, + priority=5, + ) + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid annotation_import.wait_until_done() diff --git a/libs/labelbox/tests/data/annotation_types/classification/test_classification.py b/libs/labelbox/tests/data/annotation_types/classification/test_classification.py index 066cf91bd..801cdb232 100644 --- a/libs/labelbox/tests/data/annotation_types/classification/test_classification.py +++ b/libs/labelbox/tests/data/annotation_types/classification/test_classification.py @@ -1,8 +1,12 @@ import pytest -from labelbox.data.annotation_types import (Checklist, ClassificationAnswer, - Radio, Text, - ClassificationAnnotation) +from labelbox.data.annotation_types import ( + Checklist, + ClassificationAnswer, + Radio, + Text, + ClassificationAnnotation, +) from pydantic import ValidationError @@ -14,18 +18,21 @@ def test_classification_answer(): feature_schema_id = "immunoelectrophoretically" name = "my_feature" confidence = 0.9 - custom_metrics = [{'name': 'metric1', 'value': 2}] - answer = ClassificationAnswer(name=name, - confidence=confidence, - custom_metrics=custom_metrics) + custom_metrics = [{"name": "metric1", "value": 2}] + answer = ClassificationAnswer( + name=name, confidence=confidence, custom_metrics=custom_metrics + ) assert answer.feature_schema_id is None assert answer.name == name assert answer.confidence == confidence - assert [answer.custom_metrics[0].model_dump(exclude_none=True)] == custom_metrics + assert [ + answer.custom_metrics[0].model_dump(exclude_none=True) + ] == custom_metrics - answer = ClassificationAnswer(feature_schema_id=feature_schema_id, - name=name) + answer = ClassificationAnswer( + feature_schema_id=feature_schema_id, name=name + ) assert answer.feature_schema_id == feature_schema_id assert answer.name == name @@ -33,9 +40,13 @@ def test_classification_answer(): def test_classification(): answer = "1234" - classification = ClassificationAnnotation(value=Text(answer=answer), - name="a classification") - assert classification.model_dump(exclude_none=True)['value']['answer'] == answer + classification = ClassificationAnnotation( + value=Text(answer=answer), 
name="a classification" + ) + assert ( + classification.model_dump(exclude_none=True)["value"]["answer"] + == answer + ) with pytest.raises(ValidationError): ClassificationAnnotation() @@ -48,107 +59,98 @@ def test_subclass(): with pytest.raises(ValidationError): # Should have feature schema info classification = ClassificationAnnotation(value=Text(answer=answer)) - classification = ClassificationAnnotation(value=Text(answer=answer), - name=name) + classification = ClassificationAnnotation( + value=Text(answer=answer), name=name + ) assert classification.model_dump(exclude_none=True) == { - 'name': name, - 'extra': {}, - 'value': { - 'answer': answer, + "name": name, + "extra": {}, + "value": { + "answer": answer, }, } classification = ClassificationAnnotation( value=Text(answer=answer), name=name, - feature_schema_id=feature_schema_id) + feature_schema_id=feature_schema_id, + ) assert classification.model_dump(exclude_none=True) == { - 'feature_schema_id': feature_schema_id, - 'extra': {}, - 'value': { - 'answer': answer, + "feature_schema_id": feature_schema_id, + "extra": {}, + "value": { + "answer": answer, }, - 'name': name, + "name": name, } classification = ClassificationAnnotation( value=Text(answer=answer), feature_schema_id=feature_schema_id, - name=name) + name=name, + ) assert classification.model_dump(exclude_none=True) == { - 'name': name, - 'feature_schema_id': feature_schema_id, - 'extra': {}, - 'value': { - 'answer': answer, + "name": name, + "feature_schema_id": feature_schema_id, + "extra": {}, + "value": { + "answer": answer, }, } def test_radio(): - answer = ClassificationAnswer(name="1", - confidence=0.81, - custom_metrics=[{ - 'name': 'metric1', - 'value': 0.99 - }]) + answer = ClassificationAnswer( + name="1", + confidence=0.81, + custom_metrics=[{"name": "metric1", "value": 0.99}], + ) feature_schema_id = "immunoelectrophoretically" name = "my_feature" with pytest.raises(ValidationError): - classification = ClassificationAnnotation(value=Radio( - answer=answer.name)) + classification = ClassificationAnnotation( + value=Radio(answer=answer.name) + ) with pytest.raises(ValidationError): classification = Radio(answer=[answer]) classification = Radio(answer=answer) assert classification.model_dump(exclude_none=True) == { - 'answer': { - 'name': answer.name, - 'extra': {}, - 'confidence': 0.81, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 0.99 - }], + "answer": { + "name": answer.name, + "extra": {}, + "confidence": 0.81, + "custom_metrics": [{"name": "metric1", "value": 0.99}], } } classification = ClassificationAnnotation( value=Radio(answer=answer), feature_schema_id=feature_schema_id, name=name, - custom_metrics=[{ - 'name': 'metric1', - 'value': 0.99 - }]) + custom_metrics=[{"name": "metric1", "value": 0.99}], + ) assert classification.model_dump(exclude_none=True) == { - 'name': name, - 'feature_schema_id': feature_schema_id, - 'extra': {}, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 0.99 - }], - 'value': { - 'answer': { - 'name': answer.name, - 'extra': {}, - 'confidence': 0.81, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 0.99 - }] + "name": name, + "feature_schema_id": feature_schema_id, + "extra": {}, + "custom_metrics": [{"name": "metric1", "value": 0.99}], + "value": { + "answer": { + "name": answer.name, + "extra": {}, + "confidence": 0.81, + "custom_metrics": [{"name": "metric1", "value": 0.99}], }, }, } def test_checklist(): - answer = ClassificationAnswer(name="1", - confidence=0.99, - custom_metrics=[{ - 'name': 'metric1', 
- 'value': 2 - }]) + answer = ClassificationAnswer( + name="1", + confidence=0.99, + custom_metrics=[{"name": "metric1", "value": 2}], + ) feature_schema_id = "immunoelectrophoretically" name = "my_feature" @@ -160,15 +162,14 @@ def test_checklist(): classification = Checklist(answer=[answer]) assert classification.model_dump(exclude_none=True) == { - 'answer': [{ - 'name': answer.name, - 'extra': {}, - 'confidence': 0.99, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 2 - }], - }] + "answer": [ + { + "name": answer.name, + "extra": {}, + "confidence": 0.99, + "custom_metrics": [{"name": "metric1", "value": 2}], + } + ] } classification = ClassificationAnnotation( value=Checklist(answer=[answer]), @@ -176,18 +177,17 @@ def test_checklist(): name=name, ) assert classification.model_dump(exclude_none=True) == { - 'name': name, - 'feature_schema_id': feature_schema_id, - 'extra': {}, - 'value': { - 'answer': [{ - 'name': answer.name, - 'extra': {}, - 'confidence': 0.99, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 2 - }], - }] + "name": name, + "feature_schema_id": feature_schema_id, + "extra": {}, + "value": { + "answer": [ + { + "name": answer.name, + "extra": {}, + "confidence": 0.99, + "custom_metrics": [{"name": "metric1", "value": 2}], + } + ] }, } diff --git a/libs/labelbox/tests/data/annotation_types/data/test_raster.py b/libs/labelbox/tests/data/annotation_types/data/test_raster.py index 4ce787022..6bc8f2bbf 100644 --- a/libs/labelbox/tests/data/annotation_types/data/test_raster.py +++ b/libs/labelbox/tests/data/annotation_types/data/test_raster.py @@ -42,11 +42,13 @@ def test_ref(): uid = "uid" metadata = [] media_attributes = {} - data = ImageData(im_bytes=b'', - external_id=external_id, - uid=uid, - metadata=metadata, - media_attributes=media_attributes) + data = ImageData( + im_bytes=b"", + external_id=external_id, + uid=uid, + metadata=metadata, + media_attributes=media_attributes, + ) assert data.external_id == external_id assert data.uid == uid assert data.media_attributes == media_attributes diff --git a/libs/labelbox/tests/data/annotation_types/data/test_text.py b/libs/labelbox/tests/data/annotation_types/data/test_text.py index 0af0a37fb..865f93e65 100644 --- a/libs/labelbox/tests/data/annotation_types/data/test_text.py +++ b/libs/labelbox/tests/data/annotation_types/data/test_text.py @@ -15,9 +15,9 @@ def test_text(): text = "hello world" metadata = [] media_attributes = {} - text_data = TextData(text=text, - metadata=metadata, - media_attributes=media_attributes) + text_data = TextData( + text=text, metadata=metadata, media_attributes=media_attributes + ) assert text_data.text == text @@ -31,7 +31,7 @@ def test_url(): def test_file(tmpdir): content = "foo bar baz" file = "hello.txt" - dir = tmpdir.mkdir('data') + dir = tmpdir.mkdir("data") dir.join(file).write(content) text_data = TextData(file_path=os.path.join(dir.strpath, file)) assert len(text_data.value) == len(content) @@ -42,11 +42,13 @@ def test_ref(): uid = "uid" metadata = [] media_attributes = {} - data = TextData(text="hello world", - external_id=external_id, - uid=uid, - metadata=metadata, - media_attributes=media_attributes) + data = TextData( + text="hello world", + external_id=external_id, + uid=uid, + metadata=metadata, + media_attributes=media_attributes, + ) assert data.external_id == external_id assert data.uid == uid assert data.media_attributes == media_attributes diff --git a/libs/labelbox/tests/data/annotation_types/data/test_video.py 
b/libs/labelbox/tests/data/annotation_types/data/test_video.py index d0e5ed012..5fd77c2c8 100644 --- a/libs/labelbox/tests/data/annotation_types/data/test_video.py +++ b/libs/labelbox/tests/data/annotation_types/data/test_video.py @@ -22,7 +22,7 @@ def test_frames(): def test_file_path(): - path = 'tests/integration/media/cat.mp4' + path = "tests/integration/media/cat.mp4" raster_data = VideoData(file_path=path) with pytest.raises(ValueError): @@ -60,11 +60,13 @@ def test_ref(): } metadata = [] media_attributes = {} - data = VideoData(frames=data, - external_id=external_id, - uid=uid, - metadata=metadata, - media_attributes=media_attributes) + data = VideoData( + frames=data, + external_id=external_id, + uid=uid, + metadata=metadata, + media_attributes=media_attributes, + ) assert data.external_id == external_id assert data.uid == uid assert data.media_attributes == media_attributes diff --git a/libs/labelbox/tests/data/annotation_types/geometry/test_line.py b/libs/labelbox/tests/data/annotation_types/geometry/test_line.py index 10362e728..d6fd1108c 100644 --- a/libs/labelbox/tests/data/annotation_types/geometry/test_line.py +++ b/libs/labelbox/tests/data/annotation_types/geometry/test_line.py @@ -16,7 +16,7 @@ def test_line(): expected = {"coordinates": [points], "type": "MultiLineString"} line = Line(points=[Point(x=x, y=y) for x, y in points]) assert line.geometry == expected - expected['coordinates'] = tuple([tuple([tuple(x) for x in points])]) + expected["coordinates"] = tuple([tuple([tuple(x) for x in points])]) assert line.shapely.__geo_interface__ == expected raster = line.draw(height=32, width=32, thickness=1) diff --git a/libs/labelbox/tests/data/annotation_types/geometry/test_mask.py b/libs/labelbox/tests/data/annotation_types/geometry/test_mask.py index 960e64d9a..6fe8422cf 100644 --- a/libs/labelbox/tests/data/annotation_types/geometry/test_mask.py +++ b/libs/labelbox/tests/data/annotation_types/geometry/test_mask.py @@ -19,45 +19,114 @@ def test_mask(): mask1 = Mask(mask=mask_data, color=(255, 255, 255)) expected1 = { - 'type': - 'MultiPolygon', - 'coordinates': [ - (((0.0, 0.0), (0.0, 1.0), (0.0, 2.0), (0.0, 3.0), (0.0, 4.0), (0.0, - 5.0), - (0.0, 6.0), (0.0, 7.0), (0.0, 8.0), (0.0, 9.0), (0.0, 10.0), - (1.0, 10.0), (2.0, 10.0), (3.0, 10.0), (4.0, 10.0), (5.0, 10.0), - (6.0, 10.0), (7.0, 10.0), (8.0, 10.0), (9.0, 10.0), (10.0, 10.0), - (10.0, 9.0), (10.0, 8.0), (10.0, 7.0), (10.0, 6.0), (10.0, 5.0), - (10.0, 4.0), (10.0, 3.0), (10.0, 2.0), (10.0, 1.0), (10.0, 0.0), - (9.0, 0.0), (8.0, 0.0), (7.0, 0.0), (6.0, 0.0), (5.0, 0.0), - (4.0, 0.0), (3.0, 0.0), (2.0, 0.0), (1.0, 0.0), (0.0, 0.0)),) - ] + "type": "MultiPolygon", + "coordinates": [ + ( + ( + (0.0, 0.0), + (0.0, 1.0), + (0.0, 2.0), + (0.0, 3.0), + (0.0, 4.0), + (0.0, 5.0), + (0.0, 6.0), + (0.0, 7.0), + (0.0, 8.0), + (0.0, 9.0), + (0.0, 10.0), + (1.0, 10.0), + (2.0, 10.0), + (3.0, 10.0), + (4.0, 10.0), + (5.0, 10.0), + (6.0, 10.0), + (7.0, 10.0), + (8.0, 10.0), + (9.0, 10.0), + (10.0, 10.0), + (10.0, 9.0), + (10.0, 8.0), + (10.0, 7.0), + (10.0, 6.0), + (10.0, 5.0), + (10.0, 4.0), + (10.0, 3.0), + (10.0, 2.0), + (10.0, 1.0), + (10.0, 0.0), + (9.0, 0.0), + (8.0, 0.0), + (7.0, 0.0), + (6.0, 0.0), + (5.0, 0.0), + (4.0, 0.0), + (3.0, 0.0), + (2.0, 0.0), + (1.0, 0.0), + (0.0, 0.0), + ), + ) + ], } assert mask1.geometry == expected1 assert mask1.shapely.__geo_interface__ == expected1 mask2 = Mask(mask=mask_data, color=(0, 255, 255)) expected2 = { - 'type': - 'MultiPolygon', - 'coordinates': [ - (((20.0, 20.0), (20.0, 21.0), 
(20.0, 22.0), (20.0, 23.0), - (20.0, 24.0), (20.0, 25.0), (20.0, 26.0), (20.0, 27.0), - (20.0, 28.0), (20.0, 29.0), (20.0, 30.0), (21.0, 30.0), - (22.0, 30.0), (23.0, 30.0), (24.0, 30.0), (25.0, 30.0), - (26.0, 30.0), (27.0, 30.0), (28.0, 30.0), (29.0, 30.0), - (30.0, 30.0), (30.0, 29.0), (30.0, 28.0), (30.0, 27.0), - (30.0, 26.0), (30.0, 25.0), (30.0, 24.0), (30.0, 23.0), - (30.0, 22.0), (30.0, 21.0), (30.0, 20.0), (29.0, 20.0), - (28.0, 20.0), (27.0, 20.0), (26.0, 20.0), (25.0, 20.0), - (24.0, 20.0), (23.0, 20.0), (22.0, 20.0), (21.0, 20.0), (20.0, - 20.0)),) - ] + "type": "MultiPolygon", + "coordinates": [ + ( + ( + (20.0, 20.0), + (20.0, 21.0), + (20.0, 22.0), + (20.0, 23.0), + (20.0, 24.0), + (20.0, 25.0), + (20.0, 26.0), + (20.0, 27.0), + (20.0, 28.0), + (20.0, 29.0), + (20.0, 30.0), + (21.0, 30.0), + (22.0, 30.0), + (23.0, 30.0), + (24.0, 30.0), + (25.0, 30.0), + (26.0, 30.0), + (27.0, 30.0), + (28.0, 30.0), + (29.0, 30.0), + (30.0, 30.0), + (30.0, 29.0), + (30.0, 28.0), + (30.0, 27.0), + (30.0, 26.0), + (30.0, 25.0), + (30.0, 24.0), + (30.0, 23.0), + (30.0, 22.0), + (30.0, 21.0), + (30.0, 20.0), + (29.0, 20.0), + (28.0, 20.0), + (27.0, 20.0), + (26.0, 20.0), + (25.0, 20.0), + (24.0, 20.0), + (23.0, 20.0), + (22.0, 20.0), + (21.0, 20.0), + (20.0, 20.0), + ), + ) + ], } assert mask2.geometry == expected2 assert mask2.shapely.__geo_interface__ == expected2 - gt_mask = cv2.cvtColor(cv2.imread("tests/data/assets/mask.png"), - cv2.COLOR_BGR2RGB) + gt_mask = cv2.cvtColor( + cv2.imread("tests/data/assets/mask.png"), cv2.COLOR_BGR2RGB + ) assert (gt_mask == mask1.mask.arr).all() assert (gt_mask == mask2.mask.arr).all() @@ -66,13 +135,11 @@ def test_mask(): assert (raster1 != raster2).any() - gt1 = Rectangle(start=Point(x=0, y=0), - end=Point(x=10, y=10)).draw(height=raster1.shape[0], - width=raster1.shape[1], - color=(255, 255, 255)) - gt2 = Rectangle(start=Point(x=20, y=20), - end=Point(x=30, y=30)).draw(height=raster2.shape[0], - width=raster2.shape[1], - color=(0, 255, 255)) + gt1 = Rectangle(start=Point(x=0, y=0), end=Point(x=10, y=10)).draw( + height=raster1.shape[0], width=raster1.shape[1], color=(255, 255, 255) + ) + gt2 = Rectangle(start=Point(x=20, y=20), end=Point(x=30, y=30)).draw( + height=raster2.shape[0], width=raster2.shape[1], color=(0, 255, 255) + ) assert (raster1 == gt1).all() assert (raster2 == gt2).all() diff --git a/libs/labelbox/tests/data/annotation_types/geometry/test_point.py b/libs/labelbox/tests/data/annotation_types/geometry/test_point.py index bca3900d2..335fb6a3a 100644 --- a/libs/labelbox/tests/data/annotation_types/geometry/test_point.py +++ b/libs/labelbox/tests/data/annotation_types/geometry/test_point.py @@ -15,7 +15,7 @@ def test_point(): point = Point(x=0, y=1) expected = {"coordinates": [0, 1], "type": "Point"} assert point.geometry == expected - expected['coordinates'] = tuple(expected['coordinates']) + expected["coordinates"] = tuple(expected["coordinates"]) assert point.shapely.__geo_interface__ == expected raster = point.draw(height=32, width=32, thickness=1) diff --git a/libs/labelbox/tests/data/annotation_types/geometry/test_polygon.py b/libs/labelbox/tests/data/annotation_types/geometry/test_polygon.py index 084349023..0a0bb49b0 100644 --- a/libs/labelbox/tests/data/annotation_types/geometry/test_polygon.py +++ b/libs/labelbox/tests/data/annotation_types/geometry/test_polygon.py @@ -4,6 +4,7 @@ from labelbox.data.annotation_types import Polygon, Point from pydantic import ValidationError + def test_polygon(): with pytest.raises(ValidationError): 
polygon = Polygon() @@ -14,12 +15,13 @@ def test_polygon(): with pytest.raises(ValidationError): polygon = Polygon(points=[Point(x=0, y=1), Point(x=0, y=1)]) - points = [[0., 1.], [0., 2.], [2., 2.], [2., 0.]] + points = [[0.0, 1.0], [0.0, 2.0], [2.0, 2.0], [2.0, 0.0]] expected = {"coordinates": [points + [points[0]]], "type": "Polygon"} polygon = Polygon(points=[Point(x=x, y=y) for x, y in points]) assert polygon.geometry == expected - expected['coordinates'] = tuple( - [tuple([tuple(x) for x in points + [points[0]]])]) + expected["coordinates"] = tuple( + [tuple([tuple(x) for x in points + [points[0]]])] + ) assert polygon.shapely.__geo_interface__ == expected raster = polygon.draw(10, 10) diff --git a/libs/labelbox/tests/data/annotation_types/geometry/test_rectangle.py b/libs/labelbox/tests/data/annotation_types/geometry/test_rectangle.py index d1d7331d6..54f85eed8 100644 --- a/libs/labelbox/tests/data/annotation_types/geometry/test_rectangle.py +++ b/libs/labelbox/tests/data/annotation_types/geometry/test_rectangle.py @@ -13,12 +13,12 @@ def test_rectangle(): points = [[[0.0, 1.0], [0.0, 10.0], [10.0, 10.0], [10.0, 1.0], [0.0, 1.0]]] expected = {"coordinates": points, "type": "Polygon"} assert rectangle.geometry == expected - expected['coordinates'] = tuple([tuple([tuple(x) for x in points[0]])]) + expected["coordinates"] = tuple([tuple([tuple(x) for x in points[0]])]) assert rectangle.shapely.__geo_interface__ == expected raster = rectangle.draw(height=32, width=32) assert (cv2.imread("tests/data/assets/rectangle.png") == raster).all() - xyhw = Rectangle.from_xyhw(0., 0, 10, 10) - assert xyhw.start == Point(x=0, y=0.) + xyhw = Rectangle.from_xyhw(0.0, 0, 10, 10) + assert xyhw.start == Point(x=0, y=0.0) assert xyhw.end == Point(x=10, y=10.0) diff --git a/libs/labelbox/tests/data/annotation_types/test_annotation.py b/libs/labelbox/tests/data/annotation_types/test_annotation.py index 926d8bc97..8cdeac9ba 100644 --- a/libs/labelbox/tests/data/annotation_types/test_annotation.py +++ b/libs/labelbox/tests/data/annotation_types/test_annotation.py @@ -1,8 +1,13 @@ import pytest -from labelbox.data.annotation_types import (Text, Point, Line, - ClassificationAnnotation, - ObjectAnnotation, TextEntity) +from labelbox.data.annotation_types import ( + Text, + Point, + Line, + ClassificationAnnotation, + ObjectAnnotation, + TextEntity, +) from labelbox.data.annotation_types.video import VideoObjectAnnotation from labelbox.data.annotation_types.geometry.rectangle import Rectangle from labelbox.data.annotation_types.video import VideoClassificationAnnotation @@ -19,7 +24,11 @@ def test_annotation(): value=line, name=name, ) - assert annotation.value.points[0].model_dump() == {'extra': {}, 'x': 1., 'y': 2.} + assert annotation.value.points[0].model_dump() == { + "extra": {}, + "x": 1.0, + "y": 2.0, + } assert annotation.name == name # Check ner @@ -68,25 +77,27 @@ def test_video_annotations(): def test_confidence_for_video_is_not_supported(): with pytest.raises(ConfidenceNotSupportedException): - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=26.5), - end=Point(extra={}, - x=561.0, - y=348.0)), - classifications=[], - frame=24, - keyframe=False, - confidence=0.3434), + ( + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ 
+ "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=26.5), + end=Point(extra={}, x=561.0, y=348.0), + ), + classifications=[], + frame=24, + keyframe=False, + confidence=0.3434, + ), + ) def test_confidence_value_range_validation(): diff --git a/libs/labelbox/tests/data/annotation_types/test_collection.py b/libs/labelbox/tests/data/annotation_types/test_collection.py index 34b868162..16c9d699f 100644 --- a/libs/labelbox/tests/data/annotation_types/test_collection.py +++ b/libs/labelbox/tests/data/annotation_types/test_collection.py @@ -4,9 +4,16 @@ import numpy as np import pytest -from labelbox.data.annotation_types import (LabelGenerator, ObjectAnnotation, - ImageData, MaskData, Line, Mask, - Point, Label) +from labelbox.data.annotation_types import ( + LabelGenerator, + ObjectAnnotation, + ImageData, + MaskData, + Line, + Mask, + Point, + Label, +) from labelbox import OntologyBuilder, Tool @@ -17,7 +24,6 @@ def list_of_labels(): @pytest.fixture def signer(): - def get_signer(uuid): return lambda x: uuid @@ -25,7 +31,6 @@ def get_signer(uuid): class FakeDataset: - def __init__(self): self.uid = "ckrb4tgm51xl10ybc7lv9ghm7" self.exports = [] @@ -38,9 +43,12 @@ def create_data_row(self, row_data, external_id=None): def create_data_rows(self, args): for arg in args: self.exports.append( - SimpleNamespace(row_data=arg['row_data'], - external_id=arg['external_id'], - uid=self.uid)) + SimpleNamespace( + row_data=arg["row_data"], + external_id=arg["external_id"], + uid=self.uid, + ) + ) return self def wait_till_done(self): @@ -72,23 +80,26 @@ def test_adding_schema_ids(): data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), annotations=[ ObjectAnnotation( - value=Line( - points=[Point(x=1, y=2), Point(x=2, y=2)]), + value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]), name=name, ) - ]) + ], + ) feature_schema_id = "expected_id" - ontology = OntologyBuilder(tools=[ - Tool(Tool.Type.LINE, name=name, feature_schema_id=feature_schema_id) - ]) + ontology = OntologyBuilder( + tools=[ + Tool(Tool.Type.LINE, name=name, feature_schema_id=feature_schema_id) + ] + ) generator = LabelGenerator([label]).assign_feature_schema_ids(ontology) assert next(generator).annotations[0].feature_schema_id == feature_schema_id def test_adding_urls(signer): - label = Label(data=ImageData(arr=np.random.random((32, 32, - 3)).astype(np.uint8)), - annotations=[]) + label = Label( + data=ImageData(arr=np.random.random((32, 32, 3)).astype(np.uint8)), + annotations=[], + ) uuid = str(uuid4()) generator = LabelGenerator([label]).add_url_to_data(signer(uuid)) assert label.data.url != uuid @@ -98,9 +109,10 @@ def test_adding_urls(signer): def test_adding_to_dataset(signer): dataset = FakeDataset() - label = Label(data=ImageData(arr=np.random.random((32, 32, - 3)).astype(np.uint8)), - annotations=[]) + label = Label( + data=ImageData(arr=np.random.random((32, 32, 3)).astype(np.uint8)), + annotations=[], + ) uuid = str(uuid4()) generator = LabelGenerator([label]).add_to_dataset(dataset, signer(uuid)) assert label.data.url != uuid @@ -115,12 +127,17 @@ def test_adding_to_masks(signer): label = Label( data=ImageData(arr=np.random.random((32, 32, 3)).astype(np.uint8)), annotations=[ - ObjectAnnotation(name="1234", - value=Mask(mask=MaskData( - arr=np.random.random((32, 32, - 3)).astype(np.uint8)), - color=[255, 255, 255])) - ]) + ObjectAnnotation( + name="1234", + value=Mask( + mask=MaskData( + 
arr=np.random.random((32, 32, 3)).astype(np.uint8) + ), + color=[255, 255, 255], + ), + ) + ], + ) uuid = str(uuid4()) generator = LabelGenerator([label]).add_url_to_masks(signer(uuid)) assert label.annotations[0].value.mask.url != uuid diff --git a/libs/labelbox/tests/data/annotation_types/test_label.py b/libs/labelbox/tests/data/annotation_types/test_label.py index f0957fcee..5bdfb6bde 100644 --- a/libs/labelbox/tests/data/annotation_types/test_label.py +++ b/libs/labelbox/tests/data/annotation_types/test_label.py @@ -2,12 +2,24 @@ import numpy as np import labelbox.types as lb_types -from labelbox import OntologyBuilder, Tool, Classification as OClassification, Option -from labelbox.data.annotation_types import (ClassificationAnswer, Radio, Text, - ClassificationAnnotation, - PromptText, - ObjectAnnotation, Point, Line, - ImageData, Label) +from labelbox import ( + OntologyBuilder, + Tool, + Classification as OClassification, + Option, +) +from labelbox.data.annotation_types import ( + ClassificationAnswer, + Radio, + Text, + ClassificationAnnotation, + PromptText, + ObjectAnnotation, + Point, + Line, + ImageData, + Label, +) import pytest @@ -17,15 +29,17 @@ def test_schema_assignment_geometry(): data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), annotations=[ ObjectAnnotation( - value=Line( - points=[Point(x=1, y=2), Point(x=2, y=2)]), + value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]), name=name, ) - ]) + ], + ) feature_schema_id = "expected_id" - ontology = OntologyBuilder(tools=[ - Tool(Tool.Type.LINE, name=name, feature_schema_id=feature_schema_id) - ]) + ontology = OntologyBuilder( + tools=[ + Tool(Tool.Type.LINE, name=name, feature_schema_id=feature_schema_id) + ] + ) label.assign_feature_schema_ids(ontology) assert label.annotations[0].feature_schema_id == feature_schema_id @@ -36,38 +50,47 @@ def test_schema_assignment_classification(): text_name = "text_name" option_name = "my_option" - label = Label(data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ClassificationAnnotation(value=Radio( - answer=ClassificationAnswer(name=option_name)), - name=radio_name), - ClassificationAnnotation(value=Text(answer="some text"), - name=text_name) - ]) + label = Label( + data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), + annotations=[ + ClassificationAnnotation( + value=Radio(answer=ClassificationAnswer(name=option_name)), + name=radio_name, + ), + ClassificationAnnotation( + value=Text(answer="some text"), name=text_name + ), + ], + ) radio_schema_id = "radio_schema_id" text_schema_id = "text_schema_id" option_schema_id = "option_schema_id" ontology = OntologyBuilder( tools=[], classifications=[ - OClassification(class_type=OClassification.Type.RADIO, - name=radio_name, - feature_schema_id=radio_schema_id, - options=[ - Option(value=option_name, - feature_schema_id=option_schema_id) - ]), + OClassification( + class_type=OClassification.Type.RADIO, + name=radio_name, + feature_schema_id=radio_schema_id, + options=[ + Option( + value=option_name, feature_schema_id=option_schema_id + ) + ], + ), OClassification( class_type=OClassification.Type.TEXT, name=text_name, feature_schema_id=text_schema_id, - ) - ]) + ), + ], + ) label.assign_feature_schema_ids(ontology) assert label.annotations[0].feature_schema_id == radio_schema_id assert label.annotations[1].feature_schema_id == text_schema_id - assert label.annotations[ - 0].value.answer.feature_schema_id == option_schema_id + assert ( + label.annotations[0].value.answer.feature_schema_id == 
option_schema_id + ) def test_schema_assignment_subclass(): @@ -81,34 +104,48 @@ def test_schema_assignment_subclass(): label = Label( data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), annotations=[ - ObjectAnnotation(value=Line( - points=[Point(x=1, y=2), Point(x=2, y=2)]), - name=name, - classifications=[classification]) - ]) + ObjectAnnotation( + value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]), + name=name, + classifications=[classification], + ) + ], + ) feature_schema_id = "expected_id" classification_schema_id = "classification_id" option_schema_id = "option_schema_id" - ontology = OntologyBuilder(tools=[ - Tool(Tool.Type.LINE, - name=name, - feature_schema_id=feature_schema_id, - classifications=[ - OClassification(class_type=OClassification.Type.RADIO, - name=radio_name, - feature_schema_id=classification_schema_id, - options=[ - Option(value=option_name, - feature_schema_id=option_schema_id) - ]) - ]) - ]) + ontology = OntologyBuilder( + tools=[ + Tool( + Tool.Type.LINE, + name=name, + feature_schema_id=feature_schema_id, + classifications=[ + OClassification( + class_type=OClassification.Type.RADIO, + name=radio_name, + feature_schema_id=classification_schema_id, + options=[ + Option( + value=option_name, + feature_schema_id=option_schema_id, + ) + ], + ) + ], + ) + ] + ) label.assign_feature_schema_ids(ontology) assert label.annotations[0].feature_schema_id == feature_schema_id - assert label.annotations[0].classifications[ - 0].feature_schema_id == classification_schema_id - assert label.annotations[0].classifications[ - 0].value.answer.feature_schema_id == option_schema_id + assert ( + label.annotations[0].classifications[0].feature_schema_id + == classification_schema_id + ) + assert ( + label.annotations[0].classifications[0].value.answer.feature_schema_id + == option_schema_id + ) def test_highly_nested(): @@ -121,92 +158,117 @@ def test_highly_nested(): name=radio_name, value=Radio(answer=ClassificationAnswer(name=option_name)), classifications=[ - ClassificationAnnotation(value=Radio(answer=ClassificationAnswer( - name=nested_option_name)), - name=nested_name) - ]) + ClassificationAnnotation( + value=Radio( + answer=ClassificationAnswer(name=nested_option_name) + ), + name=nested_name, + ) + ], + ) label = Label( data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), annotations=[ - ObjectAnnotation(value=Line( - points=[Point(x=1, y=2), Point(x=2, y=2)]), - name=name, - classifications=[classification]) - ]) + ObjectAnnotation( + value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]), + name=name, + classifications=[classification], + ) + ], + ) feature_schema_id = "expected_id" classification_schema_id = "classification_id" nested_classification_schema_id = "nested_classification_schema_id" option_schema_id = "option_schema_id" - ontology = OntologyBuilder(tools=[ - Tool(Tool.Type.LINE, - name=name, - feature_schema_id=feature_schema_id, - classifications=[ - OClassification( - class_type=OClassification.Type.RADIO, - name=radio_name, - feature_schema_id=classification_schema_id, - options=[ - Option(value=option_name, + ontology = OntologyBuilder( + tools=[ + Tool( + Tool.Type.LINE, + name=name, + feature_schema_id=feature_schema_id, + classifications=[ + OClassification( + class_type=OClassification.Type.RADIO, + name=radio_name, + feature_schema_id=classification_schema_id, + options=[ + Option( + value=option_name, feature_schema_id=option_schema_id, options=[ OClassification( class_type=OClassification.Type.RADIO, name=nested_name, - 
feature_schema_id= - nested_classification_schema_id, + feature_schema_id=nested_classification_schema_id, options=[ Option( value=nested_option_name, - feature_schema_id= - nested_classification_schema_id) - ]) - ]) - ]) - ]) - ]) + feature_schema_id=nested_classification_schema_id, + ) + ], + ) + ], + ) + ], + ) + ], + ) + ] + ) label.assign_feature_schema_ids(ontology) assert label.annotations[0].feature_schema_id == feature_schema_id - assert label.annotations[0].classifications[ - 0].feature_schema_id == classification_schema_id - assert label.annotations[0].classifications[ - 0].value.answer.feature_schema_id == option_schema_id + assert ( + label.annotations[0].classifications[0].feature_schema_id + == classification_schema_id + ) + assert ( + label.annotations[0].classifications[0].value.answer.feature_schema_id + == option_schema_id + ) def test_schema_assignment_confidence(): name = "line_feature" - label = Label(data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ObjectAnnotation(value=Line( - points=[Point(x=1, y=2), - Point(x=2, y=2)],), - name=name, - confidence=0.914) - ]) + label = Label( + data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), + annotations=[ + ObjectAnnotation( + value=Line( + points=[Point(x=1, y=2), Point(x=2, y=2)], + ), + name=name, + confidence=0.914, + ) + ], + ) assert label.annotations[0].confidence == 0.914 def test_initialize_label_no_coercion(): - global_key = 'global-key' + global_key = "global-key" ner_annotation = lb_types.ObjectAnnotation( name="ner", - value=lb_types.ConversationEntity(start=0, end=8, message_id="4")) - label = Label(data=lb_types.ConversationData(global_key=global_key), - annotations=[ner_annotation]) + value=lb_types.ConversationEntity(start=0, end=8, message_id="4"), + ) + label = Label( + data=lb_types.ConversationData(global_key=global_key), + annotations=[ner_annotation], + ) assert isinstance(label.data, lb_types.ConversationData) assert label.data.global_key == global_key + def test_prompt_classification_validation(): - global_key = 'global-key' + global_key = "global-key" prompt_text = lb_types.PromptClassificationAnnotation( - name="prompt text", - value=PromptText(answer="test") + name="prompt text", value=PromptText(answer="test") ) prompt_text_2 = lb_types.PromptClassificationAnnotation( - name="prompt text", - value=PromptText(answer="test") + name="prompt text", value=PromptText(answer="test") ) with pytest.raises(TypeError) as e_info: - label = Label(data={"global_key": global_key}, - annotations=[prompt_text, prompt_text_2]) + label = Label( + data={"global_key": global_key}, + annotations=[prompt_text, prompt_text_2], + ) diff --git a/libs/labelbox/tests/data/annotation_types/test_metrics.py b/libs/labelbox/tests/data/annotation_types/test_metrics.py index d2e488109..94c9521a5 100644 --- a/libs/labelbox/tests/data/annotation_types/test_metrics.py +++ b/libs/labelbox/tests/data/annotation_types/test_metrics.py @@ -1,7 +1,13 @@ import pytest -from labelbox.data.annotation_types.metrics import ConfusionMatrixAggregation, ScalarMetricAggregation -from labelbox.data.annotation_types.metrics import ConfusionMatrixMetric, ScalarMetric +from labelbox.data.annotation_types.metrics import ( + ConfusionMatrixAggregation, + ScalarMetricAggregation, +) +from labelbox.data.annotation_types.metrics import ( + ConfusionMatrixMetric, + ScalarMetric, +) from labelbox.data.annotation_types import ScalarMetric, Label, ImageData from labelbox.data.annotation_types.metrics.scalar import 
RESERVED_METRIC_NAMES from pydantic import ValidationError @@ -12,19 +18,22 @@ def test_legacy_scalar_metric(): metric = ScalarMetric(value=value) assert metric.value == value - label = Label(data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), - annotations=[metric]) + label = Label( + data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), annotations=[metric] + ) expected = { - 'data': { - 'uid': 'ckrmd9q8g000009mg6vej7hzg', + "data": { + "uid": "ckrmd9q8g000009mg6vej7hzg", }, - 'annotations': [{ - 'aggregation': ScalarMetricAggregation.ARITHMETIC_MEAN, - 'value': 10.0, - 'extra': {}, - }], - 'extra': {}, - 'is_benchmark_reference': False + "annotations": [ + { + "aggregation": ScalarMetricAggregation.ARITHMETIC_MEAN, + "value": 10.0, + "extra": {}, + } + ], + "extra": {}, + "is_benchmark_reference": False, } assert label.model_dump(exclude_none=True) == expected @@ -32,100 +41,118 @@ def test_legacy_scalar_metric(): # TODO: Test with confidence -@pytest.mark.parametrize('feature_name,subclass_name,aggregation,value', [ - ("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), - ("cat", None, ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), - (None, None, ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), - (None, None, None, 0.5), - ("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), - ("cat", None, ScalarMetricAggregation.HARMONIC_MEAN, 0.5), - (None, None, ScalarMetricAggregation.GEOMETRIC_MEAN, 0.5), - (None, None, ScalarMetricAggregation.SUM, 0.5), - ("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, { - 0.1: 0.2, - 0.3: 0.5, - 0.4: 0.8 - }), -]) +@pytest.mark.parametrize( + "feature_name,subclass_name,aggregation,value", + [ + ("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), + ("cat", None, ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), + (None, None, ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), + (None, None, None, 0.5), + ("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), + ("cat", None, ScalarMetricAggregation.HARMONIC_MEAN, 0.5), + (None, None, ScalarMetricAggregation.GEOMETRIC_MEAN, 0.5), + (None, None, ScalarMetricAggregation.SUM, 0.5), + ( + "cat", + "orange", + ScalarMetricAggregation.ARITHMETIC_MEAN, + {0.1: 0.2, 0.3: 0.5, 0.4: 0.8}, + ), + ], +) def test_custom_scalar_metric(feature_name, subclass_name, aggregation, value): - kwargs = {'aggregation': aggregation} if aggregation is not None else {} - metric = ScalarMetric(metric_name="custom_iou", - value=value, - feature_name=feature_name, - subclass_name=subclass_name, - **kwargs) + kwargs = {"aggregation": aggregation} if aggregation is not None else {} + metric = ScalarMetric( + metric_name="custom_iou", + value=value, + feature_name=feature_name, + subclass_name=subclass_name, + **kwargs, + ) assert metric.value == value - label = Label(data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), - annotations=[metric]) + label = Label( + data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), annotations=[metric] + ) expected = { - 'data': { - 'uid': 'ckrmd9q8g000009mg6vej7hzg', + "data": { + "uid": "ckrmd9q8g000009mg6vej7hzg", }, - 'annotations': [{ - 'value': - value, - 'metric_name': - 'custom_iou', - **({ - 'feature_name': feature_name - } if feature_name else {}), - **({ - 'subclass_name': subclass_name - } if subclass_name else {}), 'aggregation': - aggregation or ScalarMetricAggregation.ARITHMETIC_MEAN, - 'extra': {} - }], - 'extra': {}, - 'is_benchmark_reference': False + "annotations": [ + { + "value": value, + "metric_name": "custom_iou", + **({"feature_name": feature_name} if 
feature_name else {}), + **({"subclass_name": subclass_name} if subclass_name else {}), + "aggregation": aggregation + or ScalarMetricAggregation.ARITHMETIC_MEAN, + "extra": {}, + } + ], + "extra": {}, + "is_benchmark_reference": False, } assert label.model_dump(exclude_none=True) == expected -@pytest.mark.parametrize('feature_name,subclass_name,aggregation,value', [ - ("cat", "orange", ConfusionMatrixAggregation.CONFUSION_MATRIX, - (0, 1, 2, 3)), - ("cat", None, ConfusionMatrixAggregation.CONFUSION_MATRIX, (0, 1, 2, 3)), - (None, None, ConfusionMatrixAggregation.CONFUSION_MATRIX, (0, 1, 2, 3)), - (None, None, None, (0, 1, 2, 3)), - ("cat", "orange", ConfusionMatrixAggregation.CONFUSION_MATRIX, { - 0.1: (0, 1, 2, 3), - 0.3: (0, 1, 2, 3), - 0.4: (0, 1, 2, 3) - }), -]) -def test_custom_confusison_matrix_metric(feature_name, subclass_name, - aggregation, value): - kwargs = {'aggregation': aggregation} if aggregation is not None else {} - metric = ConfusionMatrixMetric(metric_name="confusion_matrix_50_pct_iou", - value=value, - feature_name=feature_name, - subclass_name=subclass_name, - **kwargs) +@pytest.mark.parametrize( + "feature_name,subclass_name,aggregation,value", + [ + ( + "cat", + "orange", + ConfusionMatrixAggregation.CONFUSION_MATRIX, + (0, 1, 2, 3), + ), + ( + "cat", + None, + ConfusionMatrixAggregation.CONFUSION_MATRIX, + (0, 1, 2, 3), + ), + (None, None, ConfusionMatrixAggregation.CONFUSION_MATRIX, (0, 1, 2, 3)), + (None, None, None, (0, 1, 2, 3)), + ( + "cat", + "orange", + ConfusionMatrixAggregation.CONFUSION_MATRIX, + {0.1: (0, 1, 2, 3), 0.3: (0, 1, 2, 3), 0.4: (0, 1, 2, 3)}, + ), + ], +) +def test_custom_confusison_matrix_metric( + feature_name, subclass_name, aggregation, value +): + kwargs = {"aggregation": aggregation} if aggregation is not None else {} + metric = ConfusionMatrixMetric( + metric_name="confusion_matrix_50_pct_iou", + value=value, + feature_name=feature_name, + subclass_name=subclass_name, + **kwargs, + ) assert metric.value == value - label = Label(data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), - annotations=[metric]) + label = Label( + data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), annotations=[metric] + ) expected = { - 'data': { - 'uid': 'ckrmd9q8g000009mg6vej7hzg', + "data": { + "uid": "ckrmd9q8g000009mg6vej7hzg", }, - 'annotations': [{ - 'value': - value, - 'metric_name': - 'confusion_matrix_50_pct_iou', - **({ - 'feature_name': feature_name - } if feature_name else {}), - **({ - 'subclass_name': subclass_name - } if subclass_name else {}), 'aggregation': - aggregation or ConfusionMatrixAggregation.CONFUSION_MATRIX, - 'extra': {} - }], - 'extra': {}, - 'is_benchmark_reference': False + "annotations": [ + { + "value": value, + "metric_name": "confusion_matrix_50_pct_iou", + **({"feature_name": feature_name} if feature_name else {}), + **({"subclass_name": subclass_name} if subclass_name else {}), + "aggregation": aggregation + or ConfusionMatrixAggregation.CONFUSION_MATRIX, + "extra": {}, + } + ], + "extra": {}, + "is_benchmark_reference": False, } assert label.model_dump(exclude_none=True) == expected @@ -141,11 +168,14 @@ def test_invalid_aggregations(): metric = ScalarMetric( metric_name="invalid aggregation", value=0.1, - aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX) + aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX, + ) with pytest.raises(ValidationError) as exc_info: - metric = ConfusionMatrixMetric(metric_name="invalid aggregation", - value=[0, 1, 2, 3], - aggregation=ScalarMetricAggregation.SUM) + metric = 
ConfusionMatrixMetric( + metric_name="invalid aggregation", + value=[0, 1, 2, 3], + aggregation=ScalarMetricAggregation.SUM, + ) def test_invalid_number_of_confidence_scores(): @@ -153,17 +183,21 @@ def test_invalid_number_of_confidence_scores(): metric = ScalarMetric(metric_name="too few scores", value={0.1: 0.1}) assert "Number of confidence scores must be greater" in str(exc_info.value) with pytest.raises(ValidationError) as exc_info: - metric = ConfusionMatrixMetric(metric_name="too few scores", - value={0.1: [0, 1, 2, 3]}) + metric = ConfusionMatrixMetric( + metric_name="too few scores", value={0.1: [0, 1, 2, 3]} + ) assert "Number of confidence scores must be greater" in str(exc_info.value) with pytest.raises(ValidationError) as exc_info: - metric = ScalarMetric(metric_name="too many scores", - value={i / 20.: 0.1 for i in range(20)}) + metric = ScalarMetric( + metric_name="too many scores", + value={i / 20.0: 0.1 for i in range(20)}, + ) assert "Number of confidence scores must be greater" in str(exc_info.value) with pytest.raises(ValidationError) as exc_info: metric = ConfusionMatrixMetric( metric_name="too many scores", - value={i / 20.: [0, 1, 2, 3] for i in range(20)}) + value={i / 20.0: [0, 1, 2, 3] for i in range(20)}, + ) assert "Number of confidence scores must be greater" in str(exc_info.value) @@ -171,4 +205,4 @@ def test_invalid_number_of_confidence_scores(): def test_reserved_names(metric_name: str): with pytest.raises(ValidationError) as exc_info: ScalarMetric(metric_name=metric_name, value=0.5) - assert 'is a reserved metric name' in exc_info.value.errors()[0]['msg'] + assert "is a reserved metric name" in exc_info.value.errors()[0]["msg"] diff --git a/libs/labelbox/tests/data/annotation_types/test_ner.py b/libs/labelbox/tests/data/annotation_types/test_ner.py index 9619689b1..32f40472e 100644 --- a/libs/labelbox/tests/data/annotation_types/test_ner.py +++ b/libs/labelbox/tests/data/annotation_types/test_ner.py @@ -1,5 +1,11 @@ -from labelbox.data.annotation_types import TextEntity, DocumentEntity, DocumentTextSelection -from labelbox.data.annotation_types.ner.conversation_entity import ConversationEntity +from labelbox.data.annotation_types import ( + TextEntity, + DocumentEntity, + DocumentTextSelection, +) +from labelbox.data.annotation_types.ner.conversation_entity import ( + ConversationEntity, +) def test_ner(): @@ -11,9 +17,11 @@ def test_ner(): def test_document_entity(): - document_entity = DocumentEntity(text_selections=[ - DocumentTextSelection(token_ids=["1", "2"], group_id="1", page=1) - ]) + document_entity = DocumentEntity( + text_selections=[ + DocumentTextSelection(token_ids=["1", "2"], group_id="1", page=1) + ] + ) assert document_entity.text_selections[0].token_ids == ["1", "2"] assert document_entity.text_selections[0].group_id == "1" diff --git a/libs/labelbox/tests/data/annotation_types/test_tiled_image.py b/libs/labelbox/tests/data/annotation_types/test_tiled_image.py index aea6587f6..46f2383d6 100644 --- a/libs/labelbox/tests/data/annotation_types/test_tiled_image.py +++ b/libs/labelbox/tests/data/annotation_types/test_tiled_image.py @@ -3,10 +3,13 @@ from labelbox.data.annotation_types.geometry.point import Point from labelbox.data.annotation_types.geometry.line import Line from labelbox.data.annotation_types.geometry.rectangle import Rectangle -from labelbox.data.annotation_types.data.tiled_image import (EPSG, TiledBounds, - TileLayer, - TiledImageData, - EPSGTransformer) +from labelbox.data.annotation_types.data.tiled_image import ( + EPSG, + 
TiledBounds, + TileLayer, + TiledImageData, + EPSGTransformer, +) from pydantic import ValidationError @@ -29,21 +32,26 @@ def test_tiled_bounds(epsg): def test_tiled_bounds_same(epsg): single_bound = Point(x=0, y=0) with pytest.raises(ValidationError): - tiled_bounds = TiledBounds(epsg=epsg, - bounds=[single_bound, single_bound]) + tiled_bounds = TiledBounds( + epsg=epsg, bounds=[single_bound, single_bound] + ) def test_create_tiled_image_data(): bounds_points = [Point(x=0, y=0), Point(x=5, y=5)] - url = "https://labelbox.s3-us-west-2.amazonaws.com/pathology/{z}/{x}/{y}.png" + url = ( + "https://labelbox.s3-us-west-2.amazonaws.com/pathology/{z}/{x}/{y}.png" + ) zoom_levels = (1, 10) tile_layer = TileLayer(url=url, name="slippy map tile") tile_bounds = TiledBounds(epsg=EPSG.EPSG4326, bounds=bounds_points) - tiled_image_data = TiledImageData(tile_layer=tile_layer, - tile_bounds=tile_bounds, - zoom_levels=zoom_levels, - version=2) + tiled_image_data = TiledImageData( + tile_layer=tile_layer, + tile_bounds=tile_bounds, + zoom_levels=zoom_levels, + version=2, + ) assert isinstance(tiled_image_data, TiledImageData) assert tiled_image_data.tile_bounds.bounds == bounds_points assert tiled_image_data.tile_layer.url == url @@ -53,20 +61,24 @@ def test_create_tiled_image_data(): def test_epsg_point_projections(): zoom = 4 - bounds_simple = TiledBounds(epsg=EPSG.SIMPLEPIXEL, - bounds=[Point(x=0, y=0), - Point(x=256, y=256)]) - - bounds_3857 = TiledBounds(epsg=EPSG.EPSG3857, - bounds=[ - Point(x=-104.150390625, y=30.789036751261136), - Point(x=-81.8701171875, y=45.920587344733654) - ]) - bounds_4326 = TiledBounds(epsg=EPSG.EPSG4326, - bounds=[ - Point(x=-104.150390625, y=30.789036751261136), - Point(x=-81.8701171875, y=45.920587344733654) - ]) + bounds_simple = TiledBounds( + epsg=EPSG.SIMPLEPIXEL, bounds=[Point(x=0, y=0), Point(x=256, y=256)] + ) + + bounds_3857 = TiledBounds( + epsg=EPSG.EPSG3857, + bounds=[ + Point(x=-104.150390625, y=30.789036751261136), + Point(x=-81.8701171875, y=45.920587344733654), + ], + ) + bounds_4326 = TiledBounds( + epsg=EPSG.EPSG4326, + bounds=[ + Point(x=-104.150390625, y=30.789036751261136), + Point(x=-81.8701171875, y=45.920587344733654), + ], + ) point = Point(x=-11016716.012685884, y=5312679.21393289) point_two = Point(x=-12016716.012685884, y=5212679.21393289) @@ -82,7 +94,8 @@ def test_epsg_point_projections(): src_epsg=EPSG.EPSG3857, pixel_bounds=bounds_simple, geo_bounds=bounds_3857, - zoom=zoom) + zoom=zoom, + ) transformer_3857_4326 = EPSGTransformer.create_geo_to_geo_transformer( src_epsg=EPSG.EPSG3857, tgt_epsg=EPSG.EPSG4326, @@ -91,7 +104,8 @@ def test_epsg_point_projections(): src_epsg=EPSG.EPSG4326, pixel_bounds=bounds_simple, geo_bounds=bounds_4326, - zoom=zoom) + zoom=zoom, + ) for shape in shapes_to_test: shape_simple = transformer_3857_simple(shape=shape) diff --git a/libs/labelbox/tests/data/annotation_types/test_video.py b/libs/labelbox/tests/data/annotation_types/test_video.py index f61dc7ec7..4b92e161d 100644 --- a/libs/labelbox/tests/data/annotation_types/test_video.py +++ b/libs/labelbox/tests/data/annotation_types/test_video.py @@ -2,18 +2,19 @@ def test_mask_frame(): - mask_frame = lb_types.MaskFrame(index=1, - instance_uri="http://path/to/frame.png") + mask_frame = lb_types.MaskFrame( + index=1, instance_uri="http://path/to/frame.png" + ) assert mask_frame.model_dump(by_alias=True) == { - 'index': 1, - 'imBytes': None, - 'instanceURI': 'http://path/to/frame.png' + "index": 1, + "imBytes": None, + "instanceURI": "http://path/to/frame.png", } 
def test_mask_instance(): mask_instance = lb_types.MaskInstance(color_rgb=(0, 0, 255), name="mask1") assert mask_instance.model_dump(by_alias=True, exclude_none=True) == { - 'colorRGB': (0, 0, 255), - 'name': 'mask1' + "colorRGB": (0, 0, 255), + "name": "mask1", } diff --git a/libs/labelbox/tests/data/conftest.py b/libs/labelbox/tests/data/conftest.py index 07f3460b8..aa1379407 100644 --- a/libs/labelbox/tests/data/conftest.py +++ b/libs/labelbox/tests/data/conftest.py @@ -1,6 +1,11 @@ import pytest -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnnotation, ClassificationAnswer, Radio +from labelbox.data.annotation_types.classification.classification import ( + Checklist, + ClassificationAnnotation, + ClassificationAnswer, + Radio, +) from labelbox.data.annotation_types.geometry.point import Point from labelbox.data.annotation_types.geometry.rectangle import Rectangle from labelbox.data.annotation_types.video import VideoObjectAnnotation @@ -20,21 +25,30 @@ def bbox_video_annotation_objects(): ), classifications=[ ClassificationAnnotation( - name='nested', - value=Radio(answer=ClassificationAnswer( - name='radio_option_1', - classifications=[ - ClassificationAnnotation( - name='nested_checkbox', - value=Checklist(answer=[ - ClassificationAnswer( - name='nested_checkbox_option_1'), - ClassificationAnswer( - name='nested_checkbox_option_2') - ])) - ])), + name="nested", + value=Radio( + answer=ClassificationAnswer( + name="radio_option_1", + classifications=[ + ClassificationAnnotation( + name="nested_checkbox", + value=Checklist( + answer=[ + ClassificationAnswer( + name="nested_checkbox_option_1" + ), + ClassificationAnswer( + name="nested_checkbox_option_2" + ), + ] + ), + ) + ], + ) + ), ) - ]), + ], + ), VideoObjectAnnotation( name="bbox", keyframe=True, @@ -43,7 +57,8 @@ def bbox_video_annotation_objects(): value=Rectangle( start=Point(x=186.0, y=98.0), # Top left end=Point(x=490.0, y=341.0), # Bottom right - )) + ), + ), ] return bbox_annotation diff --git a/libs/labelbox/tests/data/export/conftest.py b/libs/labelbox/tests/data/export/conftest.py index 104ee41dc..0836c2b9e 100644 --- a/libs/labelbox/tests/data/export/conftest.py +++ b/libs/labelbox/tests/data/export/conftest.py @@ -10,225 +10,196 @@ @pytest.fixture def ontology(): bbox_tool_with_nested_text = { - 'required': - False, - 'name': - 'bbox_tool_with_nested_text', - 'tool': - 'rectangle', - 'color': - '#a23030', - 'classifications': [{ - 'required': - False, - 'instructions': - 'nested', - 'name': - 'nested', - 'type': - 'radio', - 'options': [{ - 'label': - 'radio_option_1', - 'value': - 'radio_value_1', - 'options': [{ - 'required': - False, - 'instructions': - 'nested_checkbox', - 'name': - 'nested_checkbox', - 'type': - 'checklist', - 'options': [{ - 'label': 'nested_checkbox_option_1', - 'value': 'nested_checkbox_value_1', - 'options': [] - }, { - 'label': 'nested_checkbox_option_2', - 'value': 'nested_checkbox_value_2' - }] - }, { - 'required': False, - 'instructions': 'nested_text', - 'name': 'nested_text', - 'type': 'text', - 'options': [] - }] - },] - }] + "required": False, + "name": "bbox_tool_with_nested_text", + "tool": "rectangle", + "color": "#a23030", + "classifications": [ + { + "required": False, + "instructions": "nested", + "name": "nested", + "type": "radio", + "options": [ + { + "label": "radio_option_1", + "value": "radio_value_1", + "options": [ + { + "required": False, + "instructions": "nested_checkbox", + "name": "nested_checkbox", + "type": 
"checklist", + "options": [ + { + "label": "nested_checkbox_option_1", + "value": "nested_checkbox_value_1", + "options": [], + }, + { + "label": "nested_checkbox_option_2", + "value": "nested_checkbox_value_2", + }, + ], + }, + { + "required": False, + "instructions": "nested_text", + "name": "nested_text", + "type": "text", + "options": [], + }, + ], + }, + ], + } + ], } bbox_tool = { - 'required': - False, - 'name': - 'bbox', - 'tool': - 'rectangle', - 'color': - '#a23030', - 'classifications': [{ - 'required': - False, - 'instructions': - 'nested', - 'name': - 'nested', - 'type': - 'radio', - 'options': [{ - 'label': - 'radio_option_1', - 'value': - 'radio_value_1', - 'options': [{ - 'required': - False, - 'instructions': - 'nested_checkbox', - 'name': - 'nested_checkbox', - 'type': - 'checklist', - 'options': [{ - 'label': 'nested_checkbox_option_1', - 'value': 'nested_checkbox_value_1', - 'options': [] - }, { - 'label': 'nested_checkbox_option_2', - 'value': 'nested_checkbox_value_2' - }] - }] - },] - }] + "required": False, + "name": "bbox", + "tool": "rectangle", + "color": "#a23030", + "classifications": [ + { + "required": False, + "instructions": "nested", + "name": "nested", + "type": "radio", + "options": [ + { + "label": "radio_option_1", + "value": "radio_value_1", + "options": [ + { + "required": False, + "instructions": "nested_checkbox", + "name": "nested_checkbox", + "type": "checklist", + "options": [ + { + "label": "nested_checkbox_option_1", + "value": "nested_checkbox_value_1", + "options": [], + }, + { + "label": "nested_checkbox_option_2", + "value": "nested_checkbox_value_2", + }, + ], + } + ], + }, + ], + } + ], } polygon_tool = { - 'required': False, - 'name': 'polygon', - 'tool': 'polygon', - 'color': '#FF34FF', - 'classifications': [] + "required": False, + "name": "polygon", + "tool": "polygon", + "color": "#FF34FF", + "classifications": [], } polyline_tool = { - 'required': False, - 'name': 'polyline', - 'tool': 'line', - 'color': '#FF4A46', - 'classifications': [] + "required": False, + "name": "polyline", + "tool": "line", + "color": "#FF4A46", + "classifications": [], } point_tool = { - 'required': False, - 'name': 'point--', - 'tool': 'point', - 'color': '#008941', - 'classifications': [] + "required": False, + "name": "point--", + "tool": "point", + "color": "#008941", + "classifications": [], } entity_tool = { - 'required': False, - 'name': 'entity--', - 'tool': 'named-entity', - 'color': '#006FA6', - 'classifications': [] + "required": False, + "name": "entity--", + "tool": "named-entity", + "color": "#006FA6", + "classifications": [], } segmentation_tool = { - 'required': False, - 'name': 'segmentation--', - 'tool': 'superpixel', - 'color': '#A30059', - 'classifications': [] + "required": False, + "name": "segmentation--", + "tool": "superpixel", + "color": "#A30059", + "classifications": [], } raster_segmentation_tool = { - 'required': False, - 'name': 'segmentation_mask', - 'tool': 'raster-segmentation', - 'color': '#ff0000', - 'classifications': [] + "required": False, + "name": "segmentation_mask", + "tool": "raster-segmentation", + "color": "#ff0000", + "classifications": [], } checklist = { - 'required': - False, - 'instructions': - 'checklist', - 'name': - 'checklist', - 'type': - 'checklist', - 'options': [{ - 'label': 'option1', - 'value': 'option1' - }, { - 'label': 'option2', - 'value': 'option2' - }, { - 'label': 'optionN', - 'value': 'optionn' - }] + "required": False, + "instructions": "checklist", + "name": "checklist", + "type": 
"checklist", + "options": [ + {"label": "option1", "value": "option1"}, + {"label": "option2", "value": "option2"}, + {"label": "optionN", "value": "optionn"}, + ], } checklist_index = { - 'required': - False, - 'instructions': - 'checklist_index', - 'name': - 'checklist_index', - 'type': - 'checklist', - 'scope': - 'index', - 'options': [{ - 'label': 'option1_index', - 'value': 'option1_index' - }, { - 'label': 'option2_index', - 'value': 'option2_index' - }, { - 'label': 'optionN_index', - 'value': 'optionn_index' - }] + "required": False, + "instructions": "checklist_index", + "name": "checklist_index", + "type": "checklist", + "scope": "index", + "options": [ + {"label": "option1_index", "value": "option1_index"}, + {"label": "option2_index", "value": "option2_index"}, + {"label": "optionN_index", "value": "optionn_index"}, + ], } free_form_text = { - 'required': False, - 'instructions': 'text', - 'name': 'text', - 'type': 'text', - 'options': [] + "required": False, + "instructions": "text", + "name": "text", + "type": "text", + "options": [], } free_form_text_index = { - 'required': False, - 'instructions': 'text_index', - 'name': 'text_index', - 'type': 'text', - 'scope': 'index', - 'options': [] + "required": False, + "instructions": "text_index", + "name": "text_index", + "type": "text", + "scope": "index", + "options": [], } radio = { - 'required': - False, - 'instructions': - 'radio', - 'name': - 'radio', - 'type': - 'radio', - 'options': [{ - 'label': 'first_radio_answer', - 'value': 'first_radio_answer', - 'options': [] - }, { - 'label': 'second_radio_answer', - 'value': 'second_radio_answer', - 'options': [] - }] + "required": False, + "instructions": "radio", + "name": "radio", + "type": "radio", + "options": [ + { + "label": "first_radio_answer", + "value": "first_radio_answer", + "options": [], + }, + { + "label": "second_radio_answer", + "value": "second_radio_answer", + "options": [], + }, + ], } named_entity = { - 'tool': 'named-entity', - 'name': 'named-entity', - 'required': False, - 'color': '#A30059', - 'classifications': [], + "tool": "named-entity", + "name": "named-entity", + "required": False, + "color": "#A30059", + "classifications": [], } tools = [ @@ -243,53 +214,53 @@ def ontology(): named_entity, ] classifications = [ - checklist, checklist_index, free_form_text, free_form_text_index, radio + checklist, + checklist_index, + free_form_text, + free_form_text_index, + radio, ] return {"tools": tools, "classifications": classifications} @pytest.fixture def polygon_inference(prediction_id_mapping): - polygon = prediction_id_mapping['polygon'].copy() - polygon.update({ - "polygon": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 142.769, - "y": 104.923 - }, { - "x": 57.846, - "y": 118.769 - }, { - "x": 28.308, - "y": 169.846 - }] - }) - del polygon['tool'] + polygon = prediction_id_mapping["polygon"].copy() + polygon.update( + { + "polygon": [ + {"x": 147.692, "y": 118.154}, + {"x": 142.769, "y": 104.923}, + {"x": 57.846, "y": 118.769}, + {"x": 28.308, "y": 169.846}, + ] + } + ) + del polygon["tool"] return polygon @pytest.fixture -def configured_project_with_ontology(client, initial_dataset, ontology, - rand_gen, image_url): +def configured_project_with_ontology( + client, initial_dataset, ontology, rand_gen, image_url +): dataset = initial_dataset project = client.create_project( name=rand_gen(str), queue_mode=QueueMode.Batch, ) editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] + 
client.get_labeling_frontends(where=LabelingFrontend.name == "editor") + )[0] project.setup(editor, ontology) data_row_ids = [] - for _ in range(len(ontology['tools']) + len(ontology['classifications'])): + for _ in range(len(ontology["tools"]) + len(ontology["classifications"])): data_row_ids.append(dataset.create_data_row(row_data=image_url).uid) project.create_batch( rand_gen(str), data_row_ids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) project.data_row_ids = data_row_ids yield project @@ -298,33 +269,44 @@ def configured_project_with_ontology(client, initial_dataset, ontology, @pytest.fixture def configured_project_without_data_rows(client, ontology, rand_gen): - project = client.create_project(name=rand_gen(str), - description=rand_gen(str), - queue_mode=QueueMode.Batch) + project = client.create_project( + name=rand_gen(str), + description=rand_gen(str), + queue_mode=QueueMode.Batch, + ) editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] + client.get_labeling_frontends(where=LabelingFrontend.name == "editor") + )[0] project.setup(editor, ontology) yield project project.delete() @pytest.fixture -def model_run_with_data_rows(client, configured_project_with_ontology, - model_run_predictions, model_run, - wait_for_label_processing): +def model_run_with_data_rows( + client, + configured_project_with_ontology, + model_run_predictions, + model_run, + wait_for_label_processing, +): configured_project_with_ontology.enable_model_assisted_labeling() - use_data_row_ids = [p['dataRow']['id'] for p in model_run_predictions] + use_data_row_ids = [p["dataRow"]["id"] for p in model_run_predictions] model_run.upsert_data_rows(use_data_row_ids) upload_task = LabelImport.create_from_objects( - client, configured_project_with_ontology.uid, - f"label-import-{uuid.uuid4()}", model_run_predictions) + client, + configured_project_with_ontology.uid, + f"label-import-{uuid.uuid4()}", + model_run_predictions, + ) upload_task.wait_until_done() - assert upload_task.state == AnnotationImportState.FINISHED, "Label Import did not finish" - assert len( - upload_task.errors - ) == 0, f"Label Import {upload_task.name} failed with errors {upload_task.errors}" + assert ( + upload_task.state == AnnotationImportState.FINISHED + ), "Label Import did not finish" + assert ( + len(upload_task.errors) == 0 + ), f"Label Import {upload_task.name} failed with errors {upload_task.errors}" labels = wait_for_label_processing(configured_project_with_ontology) label_ids = [label.uid for label in labels] model_run.upsert_labels(label_ids) @@ -334,8 +316,9 @@ def model_run_with_data_rows(client, configured_project_with_ontology, @pytest.fixture -def model_run_predictions(polygon_inference, rectangle_inference, - line_inference): +def model_run_predictions( + polygon_inference, rectangle_inference, line_inference +): # Not supporting mask since there isn't a signed url representing a seg mask to upload return [polygon_inference, rectangle_inference, line_inference] @@ -398,23 +381,26 @@ def prediction_id_mapping(configured_project_with_ontology): ontology = project.ontology().normalized result = {} - for idx, tool in enumerate(ontology['tools'] + ontology['classifications']): - if 'tool' in tool: - tool_type = tool['tool'] + for idx, tool in enumerate(ontology["tools"] + ontology["classifications"]): + if "tool" in tool: + tool_type = tool["tool"] else: - tool_type = tool[ - 'type'] if 'scope' not in tool else 
f"{tool['type']}_{tool['scope']}" # so 'checklist' of 'checklist_index' + tool_type = ( + tool["type"] + if "scope" not in tool + else f"{tool['type']}_{tool['scope']}" + ) # so 'checklist' of 'checklist_index' # TODO: remove this once we have a better way to associate multiple tools instances with a single tool type - if tool_type == 'rectangle': + if tool_type == "rectangle": value = { "uuid": str(uuid.uuid4()), - "schemaId": tool['featureSchemaId'], - "name": tool['name'], + "schemaId": tool["featureSchemaId"], + "name": tool["name"], "dataRow": { "id": project.data_row_ids[idx], }, - 'tool': tool + "tool": tool, } if tool_type not in result: result[tool_type] = [] @@ -422,86 +408,76 @@ def prediction_id_mapping(configured_project_with_ontology): else: result[tool_type] = { "uuid": str(uuid.uuid4()), - "schemaId": tool['featureSchemaId'], - "name": tool['name'], + "schemaId": tool["featureSchemaId"], + "name": tool["name"], "dataRow": { "id": project.data_row_ids[idx], }, - 'tool': tool + "tool": tool, } return result @pytest.fixture def line_inference(prediction_id_mapping): - line = prediction_id_mapping['line'].copy() + line = prediction_id_mapping["line"].copy() line.update( - {"line": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 150.692, - "y": 160.154 - }]}) - del line['tool'] + {"line": [{"x": 147.692, "y": 118.154}, {"x": 150.692, "y": 160.154}]} + ) + del line["tool"] return line @pytest.fixture def polygon_inference(prediction_id_mapping): - polygon = prediction_id_mapping['polygon'].copy() - polygon.update({ - "polygon": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 142.769, - "y": 104.923 - }, { - "x": 57.846, - "y": 118.769 - }, { - "x": 28.308, - "y": 169.846 - }] - }) - del polygon['tool'] + polygon = prediction_id_mapping["polygon"].copy() + polygon.update( + { + "polygon": [ + {"x": 147.692, "y": 118.154}, + {"x": 142.769, "y": 104.923}, + {"x": 57.846, "y": 118.769}, + {"x": 28.308, "y": 169.846}, + ] + } + ) + del polygon["tool"] return polygon def find_tool_by_name(tool_instances, name): for tool in tool_instances: - if tool['name'] == name: + if tool["name"] == name: return tool return None @pytest.fixture def rectangle_inference(prediction_id_mapping): - tool_instance = find_tool_by_name(prediction_id_mapping['rectangle'], - 'bbox') + tool_instance = find_tool_by_name( + prediction_id_mapping["rectangle"], "bbox" + ) rectangle = tool_instance.copy() - rectangle.update({ - "bbox": { - "top": 48, - "left": 58, - "height": 65, - "width": 12 - }, - 'classifications': [{ - "schemaId": - rectangle['tool']['classifications'][0]['featureSchemaId'], - "name": - rectangle['tool']['classifications'][0]['name'], - "answer": { - "schemaId": - rectangle['tool']['classifications'][0]['options'][0] - ['featureSchemaId'], - "name": - rectangle['tool']['classifications'][0]['options'][0] - ['value'] - } - }] - }) - del rectangle['tool'] + rectangle.update( + { + "bbox": {"top": 48, "left": 58, "height": 65, "width": 12}, + "classifications": [ + { + "schemaId": rectangle["tool"]["classifications"][0][ + "featureSchemaId" + ], + "name": rectangle["tool"]["classifications"][0]["name"], + "answer": { + "schemaId": rectangle["tool"]["classifications"][0][ + "options" + ][0]["featureSchemaId"], + "name": rectangle["tool"]["classifications"][0][ + "options" + ][0]["value"], + }, + } + ], + } + ) + del rectangle["tool"] return rectangle diff --git a/libs/labelbox/tests/data/export/legacy/test_export_catalog.py b/libs/labelbox/tests/data/export/legacy/test_export_catalog.py index 
b5aa72a35..635d307f0 100644 --- a/libs/labelbox/tests/data/export/legacy/test_export_catalog.py +++ b/libs/labelbox/tests/data/export/legacy/test_export_catalog.py @@ -1,7 +1,7 @@ import pytest -@pytest.mark.parametrize('data_rows', [3], indirect=True) +@pytest.mark.parametrize("data_rows", [3], indirect=True) def test_catalog_export_v2(client, export_v2_test_helpers, data_rows): datarow_filter_size = 2 data_row_ids = [dr.uid for dr in data_rows] @@ -10,10 +10,12 @@ def test_catalog_export_v2(client, export_v2_test_helpers, data_rows): filters = {"data_row_ids": data_row_ids[:datarow_filter_size]} task_results = export_v2_test_helpers.run_catalog_export_v2_task( - client, filters=filters, params=params) + client, filters=filters, params=params + ) # only 2 datarows should be exported assert len(task_results) == datarow_filter_size # only filtered datarows should be exported - assert set([dr['data_row']['id'] for dr in task_results - ]) == set(data_row_ids[:datarow_filter_size]) + assert set([dr["data_row"]["id"] for dr in task_results]) == set( + data_row_ids[:datarow_filter_size] + ) diff --git a/libs/labelbox/tests/data/export/legacy/test_export_dataset.py b/libs/labelbox/tests/data/export/legacy/test_export_dataset.py index e4a0b50c2..1d628dc86 100644 --- a/libs/labelbox/tests/data/export/legacy/test_export_dataset.py +++ b/libs/labelbox/tests/data/export/legacy/test_export_dataset.py @@ -1,15 +1,17 @@ import pytest -@pytest.mark.parametrize('data_rows', [3], indirect=True) +@pytest.mark.parametrize("data_rows", [3], indirect=True) def test_dataset_export_v2(export_v2_test_helpers, dataset, data_rows): data_row_ids = [dr.uid for dr in data_rows] params = {"performance_details": False, "label_details": False} task_results = export_v2_test_helpers.run_dataset_export_v2_task( - dataset, params=params) + dataset, params=params + ) assert len(task_results) == len(data_row_ids) - assert set([dr['data_row']['id'] for dr in task_results - ]) == set(data_row_ids) + assert set([dr["data_row"]["id"] for dr in task_results]) == set( + data_row_ids + ) # testing with a datarow ids filter datarow_filter_size = 2 @@ -19,13 +21,15 @@ def test_dataset_export_v2(export_v2_test_helpers, dataset, data_rows): filters = {"data_row_ids": data_row_ids[:datarow_filter_size]} task_results = export_v2_test_helpers.run_dataset_export_v2_task( - dataset, filters=filters, params=params) + dataset, filters=filters, params=params + ) # only 2 datarows should be exported assert len(task_results) == datarow_filter_size # only filtered datarows should be exported - assert set([dr['data_row']['id'] for dr in task_results - ]) == set(data_row_ids[:datarow_filter_size]) + assert set([dr["data_row"]["id"] for dr in task_results]) == set( + data_row_ids[:datarow_filter_size] + ) # testing with a global key and a datarow id filter datarow_filter_size = 2 @@ -35,10 +39,12 @@ def test_dataset_export_v2(export_v2_test_helpers, dataset, data_rows): filters = {"global_keys": global_keys[:datarow_filter_size]} task_results = export_v2_test_helpers.run_dataset_export_v2_task( - dataset, filters=filters, params=params) + dataset, filters=filters, params=params + ) # only 2 datarows should be exported assert len(task_results) == datarow_filter_size # only filtered datarows should be exported - assert set([dr['data_row']['global_key'] for dr in task_results - ]) == set(global_keys[:datarow_filter_size]) + assert set([dr["data_row"]["global_key"] for dr in task_results]) == set( + global_keys[:datarow_filter_size] + ) diff --git 
a/libs/labelbox/tests/data/export/legacy/test_export_model_run.py b/libs/labelbox/tests/data/export/legacy/test_export_model_run.py index 7dfd44f0c..2a06c334d 100644 --- a/libs/labelbox/tests/data/export/legacy/test_export_model_run.py +++ b/libs/labelbox/tests/data/export/legacy/test_export_model_run.py @@ -3,7 +3,7 @@ def _model_run_export_v2_results(model_run, task_name, params, num_retries=5): """Export model run results and retry if no results are returned.""" - while (num_retries > 0): + while num_retries > 0: task = model_run.export_v2(task_name, params=params) assert task.name == task_name task.wait_till_done() @@ -30,15 +30,22 @@ def test_model_run_export_v2(model_run_with_data_rows): for task_result in task_results: # Check export param handling - assert 'media_attributes' in task_result and task_result[ - 'media_attributes'] is not None - exported_model_run = task_result['experiments'][ - model_run.model_id]['runs'][model_run.uid] + assert ( + "media_attributes" in task_result + and task_result["media_attributes"] is not None + ) + exported_model_run = task_result["experiments"][model_run.model_id][ + "runs" + ][model_run.uid] task_label_ids_set = set( - map(lambda label: label['id'], exported_model_run['labels'])) + map(lambda label: label["id"], exported_model_run["labels"]) + ) task_prediction_ids_set = set( - map(lambda prediction: prediction['id'], - exported_model_run['predictions'])) + map( + lambda prediction: prediction["id"], + exported_model_run["predictions"], + ) + ) for label_id in task_label_ids_set: assert label_id in label_ids for prediction_id in task_prediction_ids_set: diff --git a/libs/labelbox/tests/data/export/legacy/test_export_project.py b/libs/labelbox/tests/data/export/legacy/test_export_project.py index f7716d5c5..3cd3b9226 100644 --- a/libs/labelbox/tests/data/export/legacy/test_export_project.py +++ b/libs/labelbox/tests/data/export/legacy/test_export_project.py @@ -10,9 +10,12 @@ IMAGE_URL = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg" -def test_project_export_v2(client, export_v2_test_helpers, - configured_project_with_label, - wait_for_data_row_processing): +def test_project_export_v2( + client, + export_v2_test_helpers, + configured_project_with_label, + wait_for_data_row_processing, +): project, dataset, data_row, label = configured_project_with_label data_row = wait_for_data_row_processing(client, data_row) label_id = label.uid @@ -23,55 +26,63 @@ def test_project_export_v2(client, export_v2_test_helpers, "include_labels": True, "media_type_override": MediaType.Image, "project_details": True, - "data_row_details": True + "data_row_details": True, } task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, params=params) + project, task_name=task_name, params=params + ) for task_result in task_results: - task_media_attributes = task_result['media_attributes'] - task_project = task_result['projects'][project.uid] + task_media_attributes = task_result["media_attributes"] + task_project = task_result["projects"][project.uid] task_project_label_ids_set = set( - map(lambda prediction: prediction['id'], task_project['labels'])) - task_project_details = task_project['project_details'] - task_data_row = task_result['data_row'] - task_data_row_details = task_data_row['details'] + map(lambda prediction: prediction["id"], task_project["labels"]) + ) + task_project_details = task_project["project_details"] + task_data_row = task_result["data_row"] + 
task_data_row_details = task_data_row["details"] assert label_id in task_project_label_ids_set # data row - assert task_data_row['id'] == data_row.uid - assert task_data_row['external_id'] == data_row.external_id - assert task_data_row['row_data'] == data_row.row_data + assert task_data_row["id"] == data_row.uid + assert task_data_row["external_id"] == data_row.external_id + assert task_data_row["row_data"] == data_row.row_data # data row details - assert task_data_row_details['dataset_id'] == dataset.uid - assert task_data_row_details['dataset_name'] == dataset.name + assert task_data_row_details["dataset_id"] == dataset.uid + assert task_data_row_details["dataset_name"] == dataset.name - assert task_data_row_details['last_activity_at'] is not None - assert task_data_row_details['created_by'] is not None + assert task_data_row_details["last_activity_at"] is not None + assert task_data_row_details["created_by"] is not None # media attributes - assert task_media_attributes['mime_type'] == data_row.media_attributes[ - 'mimeType'] + assert ( + task_media_attributes["mime_type"] + == data_row.media_attributes["mimeType"] + ) # project name and details - assert task_project['name'] == project.name + assert task_project["name"] == project.name batch = next(project.batches()) - assert task_project_details['batch_id'] == batch.uid - assert task_project_details['batch_name'] == batch.name - assert task_project_details['priority'] is not None - assert task_project_details[ - 'consensus_expected_label_count'] is not None - assert task_project_details['workflow_history'] is not None + assert task_project_details["batch_id"] == batch.uid + assert task_project_details["batch_name"] == batch.name + assert task_project_details["priority"] is not None + assert ( + task_project_details["consensus_expected_label_count"] is not None + ) + assert task_project_details["workflow_history"] is not None # label details - assert task_project['labels'][0]['id'] == label_id + assert task_project["labels"][0]["id"] == label_id -def test_project_export_v2_date_filters(client, export_v2_test_helpers, - configured_project_with_label, - wait_for_data_row_processing): +def test_project_export_v2_date_filters( + client, + export_v2_test_helpers, + configured_project_with_label, + wait_for_data_row_processing, +): project, _, data_row, label = configured_project_with_label data_row = wait_for_data_row_processing(client, data_row) label_id = label.uid @@ -81,7 +92,7 @@ def test_project_export_v2_date_filters(client, export_v2_test_helpers, filters = { "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "task_queue_status": "InReview" + "task_queue_status": "InReview", } # TODO: Right now we don't have a way to test this @@ -90,24 +101,27 @@ def test_project_export_v2_date_filters(client, export_v2_test_helpers, "performance_details": include_performance_details, "include_labels": True, "project_details": True, - "media_type_override": MediaType.Image + "media_type_override": MediaType.Image, } task_queues = project.task_queues() review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") + tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE" + ) project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid) task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, filters=filters, params=params) + project, task_name=task_name, filters=filters, 
params=params + ) for task_result in task_results: - task_project = task_result['projects'][project.uid] + task_project = task_result["projects"][project.uid] task_project_label_ids_set = set( - map(lambda prediction: prediction['id'], task_project['labels'])) + map(lambda prediction: prediction["id"], task_project["labels"]) + ) assert label_id in task_project_label_ids_set - assert task_project['project_details']['workflow_status'] == 'IN_REVIEW' + assert task_project["project_details"]["workflow_status"] == "IN_REVIEW" # TODO: Add back in when we have a way to test this # if include_performance_details: @@ -124,9 +138,12 @@ def test_project_export_v2_date_filters(client, export_v2_test_helpers, export_v2_test_helpers.run_project_export_v2_task(project, filters=filters) -def test_project_export_v2_with_iso_date_filters(client, export_v2_test_helpers, - configured_project_with_label, - wait_for_data_row_processing): +def test_project_export_v2_with_iso_date_filters( + client, + export_v2_test_helpers, + configured_project_with_label, + wait_for_data_row_processing, +): project, _, data_row, label = configured_project_with_label data_row = wait_for_data_row_processing(client, data_row) label_id = label.uid @@ -135,33 +152,40 @@ def test_project_export_v2_with_iso_date_filters(client, export_v2_test_helpers, filters = { "last_activity_at": [ - "2000-01-01T00:00:00+0230", "2050-01-01T00:00:00+0230" + "2000-01-01T00:00:00+0230", + "2050-01-01T00:00:00+0230", ], "label_created_at": [ - "2000-01-01T00:00:00+0230", "2050-01-01T00:00:00+0230" - ] + "2000-01-01T00:00:00+0230", + "2050-01-01T00:00:00+0230", + ], } task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, filters=filters) - assert label_id == task_results[0]['projects'][ - project.uid]['labels'][0]['id'] + project, task_name=task_name, filters=filters + ) + assert ( + label_id == task_results[0]["projects"][project.uid]["labels"][0]["id"] + ) filters = {"last_activity_at": [None, "2050-01-01T00:00:00+0230"]} task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, filters=filters) - assert label_id == task_results[0]['projects'][ - project.uid]['labels'][0]['id'] + project, task_name=task_name, filters=filters + ) + assert ( + label_id == task_results[0]["projects"][project.uid]["labels"][0]["id"] + ) filters = {"label_created_at": ["2050-01-01T00:00:00+0230", None]} task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, filters=filters) + project, task_name=task_name, filters=filters + ) assert len(task_results) == 0 @pytest.mark.parametrize("data_rows", [3], indirect=True) def test_project_export_v2_datarows_filter( - export_v2_test_helpers, - configured_batch_project_with_multiple_datarows): + export_v2_test_helpers, configured_batch_project_with_multiple_datarows +): project, _, data_rows = configured_batch_project_with_multiple_datarows data_row_ids = [dr.uid for dr in data_rows] @@ -170,39 +194,47 @@ def test_project_export_v2_datarows_filter( filters = { "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "data_row_ids": data_row_ids[:datarow_filter_size] + "data_row_ids": data_row_ids[:datarow_filter_size], } params = {"data_row_details": True, "media_type_override": MediaType.Image} task_results = export_v2_test_helpers.run_project_export_v2_task( - project, filters=filters, params=params) + project, filters=filters, 
params=params + ) # only 2 datarows should be exported assert len(task_results) == datarow_filter_size # only filtered datarows should be exported - assert set([dr['data_row']['id'] for dr in task_results - ]) == set(data_row_ids[:datarow_filter_size]) + assert set([dr["data_row"]["id"] for dr in task_results]) == set( + data_row_ids[:datarow_filter_size] + ) global_keys = [dr.global_key for dr in data_rows] filters = { "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "global_keys": global_keys[:datarow_filter_size] + "global_keys": global_keys[:datarow_filter_size], } params = {"data_row_details": True, "media_type_override": MediaType.Image} task_results = export_v2_test_helpers.run_project_export_v2_task( - project, filters=filters, params=params) + project, filters=filters, params=params + ) # only 2 datarows should be exported assert len(task_results) == datarow_filter_size # only filtered datarows should be exported - assert set([dr['data_row']['global_key'] for dr in task_results - ]) == set(global_keys[:datarow_filter_size]) + assert set([dr["data_row"]["global_key"] for dr in task_results]) == set( + global_keys[:datarow_filter_size] + ) def test_batch_project_export_v2( - configured_batch_project_with_label: Tuple[Project, Dataset, DataRow, - Label], - export_v2_test_helpers, dataset: Dataset, image_url: str): + configured_batch_project_with_label: Tuple[ + Project, Dataset, DataRow, Label + ], + export_v2_test_helpers, + dataset: Dataset, + image_url: str, +): project, dataset, *_ = configured_batch_project_with_label batch = list(project.batches())[0] @@ -214,23 +246,24 @@ def test_batch_project_export_v2( params = { "include_performance_details": True, "include_labels": True, - "media_type_override": MediaType.Image + "media_type_override": MediaType.Image, } task_name = "test_batch_export_v2" - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image" - }, - ] * 2) + task = dataset.create_data_rows( + [ + {"row_data": image_url, "external_id": "my-image"}, + ] + * 2 + ) task.wait_till_done() data_rows = [dr.uid for dr in list(dataset.export_data_rows())] - batch_one = f'batch one {uuid.uuid4()}' + batch_one = f"batch one {uuid.uuid4()}" # This test creates two batches, only one batch should be exporter # Creatin second batch that will not be used in the export due to the filter: batch_id project.create_batch(batch_one, data_rows) task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, filters=filters, params=params) - assert (batch.size == len(task_results)) + project, task_name=task_name, filters=filters, params=params + ) + assert batch.size == len(task_results) diff --git a/libs/labelbox/tests/data/export/legacy/test_export_slice.py b/libs/labelbox/tests/data/export/legacy/test_export_slice.py index 2caa6b227..3d1fb7898 100644 --- a/libs/labelbox/tests/data/export/legacy/test_export_slice.py +++ b/libs/labelbox/tests/data/export/legacy/test_export_slice.py @@ -2,15 +2,15 @@ @pytest.mark.skip( - 'Skipping until we have a way to create slices programatically') + "Skipping until we have a way to create slices programatically" +) def test_export_v2_slice(client): # Since we don't have CRUD for slices, we'll just use the one that's already there SLICE_ID = "clk04g1e4000ryb0rgsvy1dty" slice = client.get_catalog_slice(SLICE_ID) - task = slice.export_v2(params={ - "performance_details": False, - "label_details": True - 
}) + task = slice.export_v2( + params={"performance_details": False, "label_details": True} + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None diff --git a/libs/labelbox/tests/data/export/legacy/test_export_video.py b/libs/labelbox/tests/data/export/legacy/test_export_video.py index 3a0cb4149..75a57eca9 100644 --- a/libs/labelbox/tests/data/export/legacy/test_export_video.py +++ b/libs/labelbox/tests/data/export/legacy/test_export_video.py @@ -25,7 +25,6 @@ def test_export_v2_video( bbox_video_annotation_objects, rand_gen, ): - project = configured_project_without_data_rows project_id = project.uid labels = [] @@ -34,17 +33,20 @@ def test_export_v2_video( project.create_batch( rand_gen(str), data_row_uids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) for data_row_uid in data_row_uids: labels = [ - lb_types.Label(data=VideoData(uid=data_row_uid), - annotations=bbox_video_annotation_objects) + lb_types.Label( + data=VideoData(uid=data_row_uid), + annotations=bbox_video_annotation_objects, + ) ] label_import = lb.LabelImport.create_from_objects( - client, project_id, f'test-import-{project_id}', labels) + client, project_id, f"test-import-{project_id}", labels + ) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED @@ -53,13 +55,14 @@ def test_export_v2_video( num_retries = 5 task = None - while (num_retries > 0): + while num_retries > 0: task = project.export_v2( params={ "performance_details": False, "label_details": True, - "interpolated_frames": True - }) + "interpolated_frames": True, + } + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None @@ -70,129 +73,135 @@ def test_export_v2_video( break export_data = task.result - data_row_export = export_data[0]['data_row'] - assert data_row_export['global_key'] == video_data_row['global_key'] - assert data_row_export['row_data'] == video_data_row['row_data'] - assert export_data[0]['media_attributes']['mime_type'] == 'video/mp4' - assert export_data[0]['media_attributes'][ - 'frame_rate'] == 10 # as per the video_data fixture - assert export_data[0]['media_attributes'][ - 'frame_count'] == 100 # as per the video_data fixture + data_row_export = export_data[0]["data_row"] + assert data_row_export["global_key"] == video_data_row["global_key"] + assert data_row_export["row_data"] == video_data_row["row_data"] + assert export_data[0]["media_attributes"]["mime_type"] == "video/mp4" + assert ( + export_data[0]["media_attributes"]["frame_rate"] == 10 + ) # as per the video_data fixture + assert ( + export_data[0]["media_attributes"]["frame_count"] == 100 + ) # as per the video_data fixture expected_export_label = { - 'label_kind': 'Video', - 'version': '1.0.0', - 'id': 'clgjnpysl000xi3zxtnp29fug', - 'label_details': { - 'created_at': '2023-04-16T17:04:23+00:00', - 'updated_at': '2023-04-16T17:04:23+00:00', - 'created_by': 'vbrodsky@labelbox.com', - 'content_last_updated_at': '2023-04-16T17:04:23+00:00', - 'reviews': [] + "label_kind": "Video", + "version": "1.0.0", + "id": "clgjnpysl000xi3zxtnp29fug", + "label_details": { + "created_at": "2023-04-16T17:04:23+00:00", + "updated_at": "2023-04-16T17:04:23+00:00", + "created_by": "vbrodsky@labelbox.com", + "content_last_updated_at": "2023-04-16T17:04:23+00:00", + "reviews": [], }, - 'annotations': { - 'frames': { - '13': { - 'objects': { - 'clgjnpyse000ui3zx6fr1d880': { - 'feature_id': 'clgjnpyse000ui3zx6fr1d880', - 'name': 
'bbox', - 'annotation_kind': 'VideoBoundingBox', - 'classifications': [{ - 'feature_id': 'clgjnpyse000vi3zxtgtfh01y', - 'name': 'nested', - 'radio_answer': { - 'feature_id': 'clgjnpyse000wi3zxnxgv53ps', - 'name': 'radio_option_1', - 'classifications': [] + "annotations": { + "frames": { + "13": { + "objects": { + "clgjnpyse000ui3zx6fr1d880": { + "feature_id": "clgjnpyse000ui3zx6fr1d880", + "name": "bbox", + "annotation_kind": "VideoBoundingBox", + "classifications": [ + { + "feature_id": "clgjnpyse000vi3zxtgtfh01y", + "name": "nested", + "radio_answer": { + "feature_id": "clgjnpyse000wi3zxnxgv53ps", + "name": "radio_option_1", + "classifications": [], + }, } - }], - 'bounding_box': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - } + ], + "bounding_box": { + "top": 98.0, + "left": 146.0, + "height": 243.0, + "width": 236.0, + }, } }, - 'classifications': [] + "classifications": [], }, - '18': { - 'objects': { - 'clgjnpyse000ui3zx6fr1d880': { - 'feature_id': 'clgjnpyse000ui3zx6fr1d880', - 'name': 'bbox', - 'annotation_kind': 'VideoBoundingBox', - 'classifications': [{ - 'feature_id': 'clgjnpyse000vi3zxtgtfh01y', - 'name': 'nested', - 'radio_answer': { - 'feature_id': 'clgjnpyse000wi3zxnxgv53ps', - 'name': 'radio_option_1', - 'classifications': [] + "18": { + "objects": { + "clgjnpyse000ui3zx6fr1d880": { + "feature_id": "clgjnpyse000ui3zx6fr1d880", + "name": "bbox", + "annotation_kind": "VideoBoundingBox", + "classifications": [ + { + "feature_id": "clgjnpyse000vi3zxtgtfh01y", + "name": "nested", + "radio_answer": { + "feature_id": "clgjnpyse000wi3zxnxgv53ps", + "name": "radio_option_1", + "classifications": [], + }, } - }], - 'bounding_box': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - } + ], + "bounding_box": { + "top": 98.0, + "left": 146.0, + "height": 243.0, + "width": 236.0, + }, } }, - 'classifications': [] + "classifications": [], }, - '19': { - 'objects': { - 'clgjnpyse000ui3zx6fr1d880': { - 'feature_id': 'clgjnpyse000ui3zx6fr1d880', - 'name': 'bbox', - 'annotation_kind': 'VideoBoundingBox', - 'classifications': [], - 'bounding_box': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - } + "19": { + "objects": { + "clgjnpyse000ui3zx6fr1d880": { + "feature_id": "clgjnpyse000ui3zx6fr1d880", + "name": "bbox", + "annotation_kind": "VideoBoundingBox", + "classifications": [], + "bounding_box": { + "top": 98.0, + "left": 146.0, + "height": 243.0, + "width": 236.0, + }, } }, - 'classifications': [] - } - }, - 'segments': { - 'clgjnpyse000ui3zx6fr1d880': [[13, 13], [18, 19]] + "classifications": [], + }, }, - 'key_frame_feature_map': { - 'clgjnpyse000ui3zx6fr1d880': { - '13': True, - '18': False, - '19': True + "segments": {"clgjnpyse000ui3zx6fr1d880": [[13, 13], [18, 19]]}, + "key_frame_feature_map": { + "clgjnpyse000ui3zx6fr1d880": { + "13": True, + "18": False, + "19": True, } }, - 'classifications': [] - } + "classifications": [], + }, } - project_export_labels = export_data[0]['projects'][project_id]['labels'] - assert (len(project_export_labels) == len(labels) - ) #note we create 1 label per data row, 1 data row so 1 label + project_export_labels = export_data[0]["projects"][project_id]["labels"] + assert len(project_export_labels) == len( + labels + ) # note we create 1 label per data row, 1 data row so 1 label export_label = project_export_labels[0] - assert (export_label['label_kind']) == 'Video' + assert (export_label["label_kind"]) == "Video" - assert (export_label['label_details'].keys() - ) == 
expected_export_label['label_details'].keys() + assert (export_label["label_details"].keys()) == expected_export_label[ + "label_details" + ].keys() expected_frames_ids = [ vannotation.frame for vannotation in bbox_video_annotation_objects ] - export_annotations = export_label['annotations'] - export_frames = export_annotations['frames'] + export_annotations = export_label["annotations"] + export_frames = export_annotations["frames"] export_frames_ids = [int(frame_id) for frame_id in export_frames.keys()] all_frames_exported = [] for value in expected_frames_ids: # note need to understand why we are exporting more frames than we created if value not in export_frames_ids: all_frames_exported.append(value) - assert (len(all_frames_exported) == 0) + assert len(all_frames_exported) == 0 # BEGINNING OF THE VIDEO INTERPOLATION ASSERTIONS first_frame_id = bbox_video_annotation_objects[0].frame @@ -203,42 +212,50 @@ def test_export_v2_video( assert export_frames_ids == expected_frame_ids - exported_objects_dict = export_frames[str(first_frame_id)]['objects'] + exported_objects_dict = export_frames[str(first_frame_id)]["objects"] # Get the label ID first_exported_label_id = list(exported_objects_dict.keys())[0] # Since the bounding box moves to the right, the interpolated frame content should start a little bit more far to the right - assert export_frames[str(first_frame_id + 1)]['objects'][ - first_exported_label_id]['bounding_box']['left'] > export_frames[ - str(first_frame_id - )]['objects'][first_exported_label_id]['bounding_box']['left'] + assert ( + export_frames[str(first_frame_id + 1)]["objects"][ + first_exported_label_id + ]["bounding_box"]["left"] + > export_frames[str(first_frame_id)]["objects"][ + first_exported_label_id + ]["bounding_box"]["left"] + ) # But it shouldn't be further than the last frame - assert export_frames[str(first_frame_id + 1)]['objects'][ - first_exported_label_id]['bounding_box']['left'] < export_frames[ - str(last_frame_id - )]['objects'][first_exported_label_id]['bounding_box']['left'] + assert ( + export_frames[str(first_frame_id + 1)]["objects"][ + first_exported_label_id + ]["bounding_box"]["left"] + < export_frames[str(last_frame_id)]["objects"][first_exported_label_id][ + "bounding_box" + ]["left"] + ) # END OF THE VIDEO INTERPOLATION ASSERTIONS - frame_with_nested_classifications = export_frames['13'] + frame_with_nested_classifications = export_frames["13"] annotation = None - for _, a in frame_with_nested_classifications['objects'].items(): - if a['name'] == 'bbox': + for _, a in frame_with_nested_classifications["objects"].items(): + if a["name"] == "bbox": annotation = a break - assert (annotation is not None) - assert (annotation['annotation_kind'] == 'VideoBoundingBox') - assert (annotation['classifications']) - assert (annotation['bounding_box'] == { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - }) - classifications = annotation['classifications'] - classification = classifications[0]['radio_answer'] - assert (classification['name'] == 'radio_option_1') - subclassifications = classification['classifications'] + assert annotation is not None + assert annotation["annotation_kind"] == "VideoBoundingBox" + assert annotation["classifications"] + assert annotation["bounding_box"] == { + "top": 98.0, + "left": 146.0, + "height": 243.0, + "width": 236.0, + } + classifications = annotation["classifications"] + classification = classifications[0]["radio_answer"] + assert classification["name"] == "radio_option_1" + subclassifications 
= classification["classifications"] # NOTE predictions services does not support nested classifications at the moment, see # https://labelbox.atlassian.net/browse/AL-5588 - assert (len(subclassifications) == 0) + assert len(subclassifications) == 0 diff --git a/libs/labelbox/tests/data/export/legacy/test_legacy_export.py b/libs/labelbox/tests/data/export/legacy/test_legacy_export.py index 31ae8ca91..93b803f7f 100644 --- a/libs/labelbox/tests/data/export/legacy/test_legacy_export.py +++ b/libs/labelbox/tests/data/export/legacy/test_legacy_export.py @@ -13,8 +13,10 @@ @pytest.mark.skip(reason="broken export v1 api, to be retired soon") def test_export_annotations_nested_checklist( - client, configured_project_with_complex_ontology, - wait_for_data_row_processing): + client, + configured_project_with_complex_ontology, + wait_for_data_row_processing, +): project, data_row = configured_project_with_complex_ontology data_row = wait_for_data_row_processing(client, data_row) ontology = project.ontology().normalized @@ -22,43 +24,44 @@ def test_export_annotations_nested_checklist( tool = ontology["tools"][0] nested_check = [ - subc for subc in tool["classifications"] + subc + for subc in tool["classifications"] if subc["name"] == "test-checklist-class" ][0] - data = [{ - "uuid": - str(uuid.uuid4()), - "schemaId": - tool['featureSchemaId'], - "dataRow": { - "id": data_row.uid - }, - "bbox": { - "top": 20, - "left": 20, - "height": 50, - "width": 50 - }, - "classifications": [{ - "schemaId": - nested_check["featureSchemaId"], - "answers": [ - { - "schemaId": nested_check["options"][0]["featureSchemaId"] - }, + data = [ + { + "uuid": str(uuid.uuid4()), + "schemaId": tool["featureSchemaId"], + "dataRow": {"id": data_row.uid}, + "bbox": {"top": 20, "left": 20, "height": 50, "width": 50}, + "classifications": [ { - "schemaId": nested_check["options"][1]["featureSchemaId"] - }, - ] - }] - }] - task = LabelImport.create_from_objects(client, project.uid, - f'label-import-{uuid.uuid4()}', data) + "schemaId": nested_check["featureSchemaId"], + "answers": [ + { + "schemaId": nested_check["options"][0][ + "featureSchemaId" + ] + }, + { + "schemaId": nested_check["options"][1][ + "featureSchemaId" + ] + }, + ], + } + ], + } + ] + task = LabelImport.create_from_objects( + client, project.uid, f"label-import-{uuid.uuid4()}", data + ) task.wait_until_done() labels = project.label_generator() object_annotation = [ - annot for annot in next(labels).annotations + annot + for annot in next(labels).annotations if isinstance(annot, ObjectAnnotation) ][0] @@ -67,29 +70,26 @@ def test_export_annotations_nested_checklist( @pytest.mark.skip(reason="broken export v1 api, to be retired soon") -def test_export_filtered_dates(client, - configured_project_with_complex_ontology): +def test_export_filtered_dates( + client, configured_project_with_complex_ontology +): project, data_row = configured_project_with_complex_ontology ontology = project.ontology().normalized tool = ontology["tools"][0] - data = [{ - "uuid": str(uuid.uuid4()), - "schemaId": tool['featureSchemaId'], - "dataRow": { - "id": data_row.uid - }, - "bbox": { - "top": 20, - "left": 20, - "height": 50, - "width": 50 + data = [ + { + "uuid": str(uuid.uuid4()), + "schemaId": tool["featureSchemaId"], + "dataRow": {"id": data_row.uid}, + "bbox": {"top": 20, "left": 20, "height": 50, "width": 50}, } - }] + ] - task = LabelImport.create_from_objects(client, project.uid, - f'label-import-{uuid.uuid4()}', data) + task = LabelImport.create_from_objects( + client, project.uid, 
f"label-import-{uuid.uuid4()}", data + ) task.wait_until_done() regular_export = project.export_labels(download=True) @@ -99,39 +99,37 @@ def test_export_filtered_dates(client, assert len(filtered_export) == 1 filtered_export_with_time = project.export_labels( - download=True, start="2020-01-01 00:00:01") + download=True, start="2020-01-01 00:00:01" + ) assert len(filtered_export_with_time) == 1 - empty_export = project.export_labels(download=True, - start="2020-01-01", - end="2020-01-02") + empty_export = project.export_labels( + download=True, start="2020-01-01", end="2020-01-02" + ) assert len(empty_export) == 0 @pytest.mark.skip(reason="broken export v1 api, to be retired soon") -def test_export_filtered_activity(client, - configured_project_with_complex_ontology): +def test_export_filtered_activity( + client, configured_project_with_complex_ontology +): project, data_row = configured_project_with_complex_ontology ontology = project.ontology().normalized tool = ontology["tools"][0] - data = [{ - "uuid": str(uuid.uuid4()), - "schemaId": tool['featureSchemaId'], - "dataRow": { - "id": data_row.uid - }, - "bbox": { - "top": 20, - "left": 20, - "height": 50, - "width": 50 + data = [ + { + "uuid": str(uuid.uuid4()), + "schemaId": tool["featureSchemaId"], + "dataRow": {"id": data_row.uid}, + "bbox": {"top": 20, "left": 20, "height": 50, "width": 50}, } - }] + ] - task = LabelImport.create_from_objects(client, project.uid, - f'label-import-{uuid.uuid4()}', data) + task = LabelImport.create_from_objects( + client, project.uid, f"label-import-{uuid.uuid4()}", data + ) task.wait_until_done() regular_export = project.export_labels(download=True) @@ -140,35 +138,41 @@ def test_export_filtered_activity(client, filtered_export = project.export_labels( download=True, last_activity_start="2020-01-01", - last_activity_end=(datetime.datetime.now() + - datetime.timedelta(days=2)).strftime("%Y-%m-%d")) + last_activity_end=( + datetime.datetime.now() + datetime.timedelta(days=2) + ).strftime("%Y-%m-%d"), + ) assert len(filtered_export) == 1 filtered_export_with_time = project.export_labels( - download=True, last_activity_start="2020-01-01 00:00:01") + download=True, last_activity_start="2020-01-01 00:00:01" + ) assert len(filtered_export_with_time) == 1 empty_export = project.export_labels( download=True, - last_activity_start=(datetime.datetime.now() + - datetime.timedelta(days=2)).strftime("%Y-%m-%d"), + last_activity_start=( + datetime.datetime.now() + datetime.timedelta(days=2) + ).strftime("%Y-%m-%d"), ) empty_export = project.export_labels( download=True, - last_activity_end=(datetime.datetime.now() - - datetime.timedelta(days=1)).strftime("%Y-%m-%d")) + last_activity_end=( + datetime.datetime.now() - datetime.timedelta(days=1) + ).strftime("%Y-%m-%d"), + ) assert len(empty_export) == 0 def test_export_data_rows(project: Project, dataset: Dataset): n_data_rows = 2 - task = dataset.create_data_rows([ - { - "row_data": IMAGE_URL, - "external_id": "my-image" - }, - ] * n_data_rows) + task = dataset.create_data_rows( + [ + {"row_data": IMAGE_URL, "external_id": "my-image"}, + ] + * n_data_rows + ) task.wait_till_done() data_rows = [dr.uid for dr in list(dataset.export_data_rows())] @@ -196,9 +200,9 @@ def test_label_export(configured_project_with_label): exported_labels_url = project.export_labels() assert exported_labels_url is not None exported_labels = requests.get(exported_labels_url) - labels = [example['ID'] for example in exported_labels.json()] + labels = [example["ID"] for example in 
exported_labels.json()] assert labels[0] == label_id - #TODO: Add test for bulk export back. + # TODO: Add test for bulk export back. # The new exporter doesn't work with the create_label mutation @@ -233,11 +237,12 @@ def test_dataset_export(dataset, image_url): @pytest.mark.skip(reason="broken export v1 api, to be retired soon") def test_data_row_export_with_empty_media_attributes( - client, configured_project_with_label, wait_for_data_row_processing): + client, configured_project_with_label, wait_for_data_row_processing +): project, _, data_row, _ = configured_project_with_label data_row = wait_for_data_row_processing(client, data_row) labels = list(project.label_generator()) - assert len( - labels - ) == 1, "Label export job unexpectedly returned an empty result set`" + assert ( + len(labels) == 1 + ), "Label export job unexpectedly returned an empty result set`" assert labels[0].data.media_attributes == {} diff --git a/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py index 0d98d8a89..3e4efbc46 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py @@ -7,9 +7,9 @@ class TestExportDataRow: - - def test_with_data_row_object(self, client, data_row, - wait_for_data_row_processing): + def test_with_data_row_object( + self, client, data_row, wait_for_data_row_processing + ): data_row = wait_for_data_row_processing(client, data_row) time.sleep(7) # temp fix for ES indexing delay export_task = DataRow.export( @@ -22,14 +22,20 @@ def test_with_data_row_object(self, client, data_row, assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1 - assert (json.loads(list(export_task.get_stream())[0].json_str) - ["data_row"]["id"] == data_row.uid) - - def test_with_data_row_object_buffered(self, client, data_row, - wait_for_data_row_processing): + assert ( + json.loads(list(export_task.get_stream())[0].json_str)["data_row"][ + "id" + ] + == data_row.uid + ) + + def test_with_data_row_object_buffered( + self, client, data_row, wait_for_data_row_processing + ): data_row = wait_for_data_row_processing(client, data_row) time.sleep(7) # temp fix for ES indexing delay export_task = DataRow.export( @@ -42,30 +48,42 @@ def test_with_data_row_object_buffered(self, client, data_row, assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1 - assert list(export_task.get_buffered_stream())[0].json["data_row"]["id"] == data_row.uid + assert ( + list(export_task.get_buffered_stream())[0].json["data_row"]["id"] + == data_row.uid + ) def test_with_id(self, client, data_row, wait_for_data_row_processing): data_row = wait_for_data_row_processing(client, data_row) time.sleep(7) # temp fix for ES indexing delay - export_task = DataRow.export(client=client, - data_rows=[data_row.uid], - 
task_name="TestExportDataRow:test_with_id") + export_task = DataRow.export( + client=client, + data_rows=[data_row.uid], + task_name="TestExportDataRow:test_with_id", + ) export_task.wait_till_done() assert export_task.status == "COMPLETE" assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1 - assert (json.loads(list(export_task.get_stream())[0].json_str) - ["data_row"]["id"] == data_row.uid) + assert ( + json.loads(list(export_task.get_stream())[0].json_str)["data_row"][ + "id" + ] + == data_row.uid + ) - def test_with_global_key(self, client, data_row, - wait_for_data_row_processing): + def test_with_global_key( + self, client, data_row, wait_for_data_row_processing + ): data_row = wait_for_data_row_processing(client, data_row) time.sleep(7) # temp fix for ES indexing delay export_task = DataRow.export( @@ -78,11 +96,16 @@ def test_with_global_key(self, client, data_row, assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1 - assert (json.loads(list(export_task.get_stream())[0].json_str) - ["data_row"]["id"] == data_row.uid) + assert ( + json.loads(list(export_task.get_stream())[0].json_str)["data_row"][ + "id" + ] + == data_row.uid + ) def test_with_invalid_id(self, client): export_task = DataRow.export( @@ -95,7 +118,10 @@ def test_with_invalid_id(self, client): assert isinstance(export_task, ExportTask) assert export_task.has_result() is False assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) is None - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) is None + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) + is None + ) + assert ( + export_task.get_total_lines(stream_type=StreamType.RESULT) is None + ) diff --git a/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py index e31f17c44..57f617a00 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py @@ -6,7 +6,6 @@ class TestExportDataset: - @pytest.mark.parametrize("data_rows", [3], indirect=True) def test_export(self, dataset, data_rows): expected_data_row_ids = [dr.uid for dr in data_rows] @@ -18,61 +17,82 @@ def test_export(self, dataset, data_rows): assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == len(expected_data_row_ids) + stream_type=StreamType.RESULT + ) == len(expected_data_row_ids) data_row_ids = list( - map(lambda x: json.loads(x.json_str)["data_row"]["id"], - export_task.get_stream())) + map( + lambda x: 
json.loads(x.json_str)["data_row"]["id"], + export_task.get_stream(), + ) + ) assert data_row_ids.sort() == expected_data_row_ids.sort() @pytest.mark.parametrize("data_rows", [3], indirect=True) def test_with_data_row_filter(self, dataset, data_rows): datarow_filter_size = 3 - expected_data_row_ids = [dr.uid for dr in data_rows - ][:datarow_filter_size] + expected_data_row_ids = [dr.uid for dr in data_rows][ + :datarow_filter_size + ] filters = {"data_row_ids": expected_data_row_ids} export_task = dataset.export( filters=filters, - task_name="TestExportDataset:test_with_data_row_filter") + task_name="TestExportDataset:test_with_data_row_filter", + ) export_task.wait_till_done() assert export_task.status == "COMPLETE" assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == datarow_filter_size + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) + assert ( + export_task.get_total_lines(stream_type=StreamType.RESULT) + == datarow_filter_size + ) data_row_ids = list( - map(lambda x: json.loads(x.json_str)["data_row"]["id"], - export_task.get_stream())) + map( + lambda x: json.loads(x.json_str)["data_row"]["id"], + export_task.get_stream(), + ) + ) assert data_row_ids.sort() == expected_data_row_ids.sort() @pytest.mark.parametrize("data_rows", [3], indirect=True) def test_with_global_key_filter(self, dataset, data_rows): datarow_filter_size = 2 - expected_global_keys = [dr.global_key for dr in data_rows - ][:datarow_filter_size] + expected_global_keys = [dr.global_key for dr in data_rows][ + :datarow_filter_size + ] filters = {"global_keys": expected_global_keys} export_task = dataset.export( filters=filters, - task_name="TestExportDataset:test_with_global_key_filter") + task_name="TestExportDataset:test_with_global_key_filter", + ) export_task.wait_till_done() assert export_task.status == "COMPLETE" assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == datarow_filter_size + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) + assert ( + export_task.get_total_lines(stream_type=StreamType.RESULT) + == datarow_filter_size + ) global_keys = list( - map(lambda x: json.loads(x.json_str)["data_row"]["global_key"], - export_task.get_stream())) + map( + lambda x: json.loads(x.json_str)["data_row"]["global_key"], + export_task.get_stream(), + ) + ) assert global_keys.sort() == expected_global_keys.sort() diff --git a/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py index b0c683486..071acbb5b 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py @@ -5,13 +5,15 @@ class TestExportEmbeddings: - - def test_export_embeddings_precomputed(self, client, dataset, environ, - image_url): - data_row_specs = [{ - "row_data": image_url, - "external_id": "image", - }] + def test_export_embeddings_precomputed( + self, client, dataset, environ, image_url + ): + data_row_specs = [ + { + "row_data": 
image_url, + "external_id": "image", + } + ] task = dataset.create_data_rows(data_row_specs) task.wait_till_done() export_task = dataset.export(params={"embeddings": True}) @@ -21,30 +23,42 @@ def test_export_embeddings_precomputed(self, client, dataset, environ, assert export_task.has_errors() is False results = [] - export_task.get_stream(converter=JsonConverter(), - stream_type=StreamType.RESULT).start( - stream_handler=lambda output: results.append( - json.loads(output.json_str))) + export_task.get_stream( + converter=JsonConverter(), stream_type=StreamType.RESULT + ).start( + stream_handler=lambda output: results.append( + json.loads(output.json_str) + ) + ) assert len(results) == len(data_row_specs) result = results[0] assert "embeddings" in result assert len(result["embeddings"]) > 0 - assert result["embeddings"][0][ - "name"] == "Image Embedding V2 (CLIP ViT-B/32)" + assert ( + result["embeddings"][0]["name"] + == "Image Embedding V2 (CLIP ViT-B/32)" + ) assert len(result["embeddings"][0]["values"]) == 1 - def test_export_embeddings_custom(self, client, dataset, image_url, - embedding): + def test_export_embeddings_custom( + self, client, dataset, image_url, embedding + ): vector = [random.uniform(1.0, 2.0) for _ in range(embedding.dims)] - import_task = dataset.create_data_rows([{ - "row_data": image_url, - "embeddings": [{ - "embedding_id": embedding.id, - "vector": vector, - }], - }]) + import_task = dataset.create_data_rows( + [ + { + "row_data": image_url, + "embeddings": [ + { + "embedding_id": embedding.id, + "vector": vector, + } + ], + } + ] + ) import_task.wait_till_done() assert import_task.status == "COMPLETE" @@ -55,15 +69,19 @@ def test_export_embeddings_custom(self, client, dataset, image_url, assert export_task.has_errors() is False results = [] - export_task.get_stream(converter=JsonConverter(), - stream_type=StreamType.RESULT).start( - stream_handler=lambda output: results.append( - json.loads(output.json_str))) + export_task.get_stream( + converter=JsonConverter(), stream_type=StreamType.RESULT + ).start( + stream_handler=lambda output: results.append( + json.loads(output.json_str) + ) + ) assert len(results) == 1 assert "embeddings" in results[0] - assert (len(results[0]["embeddings"]) - >= 1) # should at least contain the custom embedding + assert ( + len(results[0]["embeddings"]) >= 1 + ) # should at least contain the custom embedding for emb in results[0]["embeddings"]: if emb["id"] == embedding.id: assert emb["name"] == embedding.name diff --git a/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py index 0d1244660..ada493fc3 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py @@ -5,7 +5,6 @@ class TestExportModelRun: - def test_export(self, model_run_with_data_rows): model_run, labels = model_run_with_data_rows label_ids = [label.uid for label in labels] @@ -21,22 +20,31 @@ def test_export(self, model_run_with_data_rows): assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == len(expected_data_rows) + stream_type=StreamType.RESULT + ) == 
len(expected_data_rows) for data in export_task.get_stream(): obj = json.loads(data.json_str) - assert "media_attributes" in obj and obj[ - "media_attributes"] is not None + assert ( + "media_attributes" in obj + and obj["media_attributes"] is not None + ) exported_model_run = obj["experiments"][model_run.model_id]["runs"][ - model_run.uid] + model_run.uid + ] task_label_ids_set = set( - map(lambda label: label["id"], exported_model_run["labels"])) + map(lambda label: label["id"], exported_model_run["labels"]) + ) task_prediction_ids_set = set( - map(lambda prediction: prediction["id"], - exported_model_run["predictions"])) + map( + lambda prediction: prediction["id"], + exported_model_run["predictions"], + ) + ) for label_id in task_label_ids_set: assert label_id in label_ids for prediction_id in task_prediction_ids_set: diff --git a/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py index c29239887..818a0178c 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py @@ -10,16 +10,12 @@ from labelbox.schema.data_row import DataRow from labelbox.schema.label import Label -IMAGE_URL = ( - "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg" -) +IMAGE_URL = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg" class TestExportProject: - @pytest.fixture def project_export(self): - def _project_export(project, task_name, filters=None, params=None): export_task = project.export( task_name=task_name, @@ -55,8 +51,9 @@ def test_export( export_task = project_export(project, task_name, params=params) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 for data in export_task.get_stream(): @@ -64,8 +61,8 @@ def test_export( task_media_attributes = obj["media_attributes"] task_project = obj["projects"][project.uid] task_project_label_ids_set = set( - map(lambda prediction: prediction["id"], - task_project["labels"])) + map(lambda prediction: prediction["id"], task_project["labels"]) + ) task_project_details = task_project["project_details"] task_data_row = obj["data_row"] task_data_row_details = task_data_row["details"] @@ -84,8 +81,10 @@ def test_export( assert task_data_row_details["created_by"] is not None # media attributes - assert task_media_attributes[ - "mime_type"] == data_row.media_attributes["mimeType"] + assert ( + task_media_attributes["mime_type"] + == data_row.media_attributes["mimeType"] + ) # project name and details assert task_project["name"] == project.name @@ -93,8 +92,10 @@ def test_export( assert task_project_details["batch_id"] == batch.uid assert task_project_details["batch_name"] == batch.name assert task_project_details["priority"] is not None - assert task_project_details[ - "consensus_expected_label_count"] is not None + assert ( + task_project_details["consensus_expected_label_count"] + is not None + ) assert task_project_details["workflow_history"] is not None # label details @@ -125,27 +126,30 @@ def test_with_date_filters( } task_queues = project.task_queues() review_queue = next( - tq for tq in task_queues if tq.queue_type 
== "MANUAL_REVIEW_QUEUE") + tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE" + ) project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid) - export_task = project_export(project, - task_name, - filters=filters, - params=params) + export_task = project_export( + project, task_name, filters=filters, params=params + ) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 for data in export_task.get_stream(): obj = json.loads(data.json_str) task_project = obj["projects"][project.uid] task_project_label_ids_set = set( - map(lambda prediction: prediction["id"], - task_project["labels"])) + map(lambda prediction: prediction["id"], task_project["labels"]) + ) assert label_id in task_project_label_ids_set - assert task_project["project_details"][ - "workflow_status"] == "IN_REVIEW" + assert ( + task_project["project_details"]["workflow_status"] + == "IN_REVIEW" + ) def test_with_iso_date_filters( self, @@ -160,21 +164,27 @@ def test_with_iso_date_filters( task_name = "TestExportProject:test_with_iso_date_filters" filters = { "last_activity_at": [ - "2000-01-01T00:00:00+0230", "2050-01-01T00:00:00+0230" + "2000-01-01T00:00:00+0230", + "2050-01-01T00:00:00+0230", ], "label_created_at": [ - "2000-01-01T00:00:00+0230", "2050-01-01T00:00:00+0230" + "2000-01-01T00:00:00+0230", + "2050-01-01T00:00:00+0230", ], } export_task = project_export(project, task_name, filters=filters) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 - assert (label_id == json.loads( - list(export_task.get_stream())[0].json_str)["projects"][project.uid] - ["labels"][0]["id"]) + assert ( + label_id + == json.loads(list(export_task.get_stream())[0].json_str)[ + "projects" + ][project.uid]["labels"][0]["id"] + ) def test_with_iso_date_filters_no_start_date( self, @@ -191,12 +201,16 @@ def test_with_iso_date_filters_no_start_date( export_task = project_export(project, task_name, filters=filters) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 - assert (label_id == json.loads( - list(export_task.get_stream())[0].json_str)["projects"][project.uid] - ["labels"][0]["id"]) + assert ( + label_id + == json.loads(list(export_task.get_stream())[0].json_str)[ + "projects" + ][project.uid]["labels"][0]["id"] + ) def test_with_iso_date_filters_and_future_start_date( self, @@ -207,24 +221,30 @@ def test_with_iso_date_filters_and_future_start_date( ): project, _, data_row, _label = configured_project_with_label data_row = wait_for_data_row_processing(client, data_row) - task_name = "TestExportProject:test_with_iso_date_filters_and_future_start_date" + task_name = ( + "TestExportProject:test_with_iso_date_filters_and_future_start_date" + ) filters = {"label_created_at": ["2050-01-01T00:00:00+0230", None]} export_task = project_export(project, task_name, 
filters=filters) assert export_task.has_result() is False assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) is None - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) is None + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) + is None + ) + assert ( + export_task.get_total_lines(stream_type=StreamType.RESULT) is None + ) @pytest.mark.parametrize("data_rows", [3], indirect=True) def test_with_data_row_filter( - self, configured_batch_project_with_multiple_datarows, - project_export): + self, configured_batch_project_with_multiple_datarows, project_export + ): project, _, data_rows = configured_batch_project_with_multiple_datarows datarow_filter_size = 2 - expected_data_row_ids = [dr.uid for dr in data_rows - ][:datarow_filter_size] + expected_data_row_ids = [dr.uid for dr in data_rows][ + :datarow_filter_size + ] task_name = "TestExportProject:test_with_data_row_filter" filters = { "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], @@ -233,32 +253,38 @@ def test_with_data_row_filter( } params = { "data_row_details": True, - "media_type_override": MediaType.Image + "media_type_override": MediaType.Image, } - export_task = project_export(project, - task_name, - filters=filters, - params=params) + export_task = project_export( + project, task_name, filters=filters, params=params + ) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) # only 2 datarows should be exported - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == datarow_filter_size + assert ( + export_task.get_total_lines(stream_type=StreamType.RESULT) + == datarow_filter_size + ) data_row_ids = list( - map(lambda x: json.loads(x.json_str)["data_row"]["id"], - export_task.get_stream())) + map( + lambda x: json.loads(x.json_str)["data_row"]["id"], + export_task.get_stream(), + ) + ) assert data_row_ids.sort() == expected_data_row_ids.sort() @pytest.mark.parametrize("data_rows", [3], indirect=True) def test_with_global_key_filter( - self, configured_batch_project_with_multiple_datarows, - project_export): + self, configured_batch_project_with_multiple_datarows, project_export + ): project, _, data_rows = configured_batch_project_with_multiple_datarows datarow_filter_size = 2 - expected_global_keys = [dr.global_key for dr in data_rows - ][:datarow_filter_size] + expected_global_keys = [dr.global_key for dr in data_rows][ + :datarow_filter_size + ] task_name = "TestExportProject:test_with_global_key_filter" filters = { "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], @@ -267,28 +293,34 @@ def test_with_global_key_filter( } params = { "data_row_details": True, - "media_type_override": MediaType.Image + "media_type_override": MediaType.Image, } - export_task = project_export(project, - task_name, - filters=filters, - params=params) + export_task = project_export( + project, task_name, filters=filters, params=params + ) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) # only 2 datarows should be exported - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == datarow_filter_size + assert ( + 
export_task.get_total_lines(stream_type=StreamType.RESULT) + == datarow_filter_size + ) global_keys = list( - map(lambda x: json.loads(x.json_str)["data_row"]["global_key"], - export_task.get_stream())) + map( + lambda x: json.loads(x.json_str)["data_row"]["global_key"], + export_task.get_stream(), + ) + ) assert global_keys.sort() == expected_global_keys.sort() def test_batch( self, - configured_batch_project_with_label: Tuple[Project, Dataset, DataRow, - Label], + configured_batch_project_with_label: Tuple[ + Project, Dataset, DataRow, Label + ], dataset: Dataset, image_url: str, project_export, @@ -306,12 +338,12 @@ def test_batch( "media_type_override": MediaType.Image, } task_name = "TestExportProject:test_batch" - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image" - }, - ] * 2) + task = dataset.create_data_rows( + [ + {"row_data": image_url, "external_id": "my-image"}, + ] + * 2 + ) task.wait_till_done() data_rows = [result["id"] for result in task.result] batch_one = f"batch one {uuid.uuid4()}" @@ -320,13 +352,15 @@ def test_batch( # Creatin second batch that will not be used in the export due to the filter: batch_id project.create_batch(batch_one, data_rows) - export_task = project_export(project, - task_name, - filters=filters, - params=params) + export_task = project_export( + project, task_name, filters=filters, params=params + ) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == batch.size + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) + assert ( + export_task.get_total_lines(stream_type=StreamType.RESULT) + == batch.size + ) diff --git a/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py index de32509bd..115194a58 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py @@ -10,7 +10,6 @@ class TestExportVideo: - @pytest.fixture def user_id(self, client): return client.get_user().uid @@ -41,12 +40,15 @@ def test_export( for data_row_uid in data_row_uids: labels = [ - lb_types.Label(data=VideoData(uid=data_row_uid), - annotations=bbox_video_annotation_objects) + lb_types.Label( + data=VideoData(uid=data_row_uid), + annotations=bbox_video_annotation_objects, + ) ] label_import = lb.LabelImport.create_from_objects( - client, project_id, f"test-import-{project_id}", labels) + client, project_id, f"test-import-{project_id}", labels + ) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED @@ -65,18 +67,21 @@ def test_export( assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) export_data = json.loads(list(export_task.get_stream())[0].json_str) data_row_export = export_data["data_row"] assert data_row_export["global_key"] == video_data_row["global_key"] assert data_row_export["row_data"] == video_data_row["row_data"] assert export_data["media_attributes"]["mime_type"] == "video/mp4" - assert export_data["media_attributes"][ - "frame_rate"] == 10 # as per the video_data 
fixture - assert (export_data["media_attributes"]["frame_count"] == 100 - ) # as per the video_data fixture + assert ( + export_data["media_attributes"]["frame_rate"] == 10 + ) # as per the video_data fixture + assert ( + export_data["media_attributes"]["frame_count"] == 100 + ) # as per the video_data fixture expected_export_label = { "label_kind": "Video", "version": "1.0.0", @@ -96,17 +101,17 @@ def test_export( "feature_id": "clgjnpyse000ui3zx6fr1d880", "name": "bbox", "annotation_kind": "VideoBoundingBox", - "classifications": [{ - "feature_id": "clgjnpyse000vi3zxtgtfh01y", - "name": "nested", - "radio_answer": { - "feature_id": - "clgjnpyse000wi3zxnxgv53ps", - "name": - "radio_option_1", - "classifications": [], - }, - }], + "classifications": [ + { + "feature_id": "clgjnpyse000vi3zxtgtfh01y", + "name": "nested", + "radio_answer": { + "feature_id": "clgjnpyse000wi3zxnxgv53ps", + "name": "radio_option_1", + "classifications": [], + }, + } + ], "bounding_box": { "top": 98.0, "left": 146.0, @@ -123,17 +128,17 @@ def test_export( "feature_id": "clgjnpyse000ui3zx6fr1d880", "name": "bbox", "annotation_kind": "VideoBoundingBox", - "classifications": [{ - "feature_id": "clgjnpyse000vi3zxtgtfh01y", - "name": "nested", - "radio_answer": { - "feature_id": - "clgjnpyse000wi3zxnxgv53ps", - "name": - "radio_option_1", - "classifications": [], - }, - }], + "classifications": [ + { + "feature_id": "clgjnpyse000vi3zxtgtfh01y", + "name": "nested", + "radio_answer": { + "feature_id": "clgjnpyse000wi3zxnxgv53ps", + "name": "radio_option_1", + "classifications": [], + }, + } + ], "bounding_box": { "top": 98.0, "left": 146.0, @@ -162,14 +167,12 @@ def test_export( "classifications": [], }, }, - "segments": { - "clgjnpyse000ui3zx6fr1d880": [[13, 13], [18, 19]] - }, + "segments": {"clgjnpyse000ui3zx6fr1d880": [[13, 13], [18, 19]]}, "key_frame_feature_map": { "clgjnpyse000ui3zx6fr1d880": { "13": True, "18": False, - "19": True + "19": True, } }, "classifications": [], @@ -183,8 +186,9 @@ def test_export( export_label = project_export_labels[0] assert (export_label["label_kind"]) == "Video" - assert (export_label["label_details"].keys() - ) == expected_export_label["label_details"].keys() + assert (export_label["label_details"].keys()) == expected_export_label[ + "label_details" + ].keys() expected_frames_ids = [ vannotation.frame for vannotation in bbox_video_annotation_objects @@ -193,9 +197,7 @@ def test_export( export_frames = export_annotations["frames"] export_frames_ids = [int(frame_id) for frame_id in export_frames.keys()] all_frames_exported = [] - for (value) in ( - expected_frames_ids - ): # note need to understand why we are exporting more frames than we created + for value in expected_frames_ids: # note need to understand why we are exporting more frames than we created if value not in export_frames_ids: all_frames_exported.append(value) assert len(all_frames_exported) == 0 @@ -216,15 +218,23 @@ def test_export( # Since the bounding box moves to the right, the interpolated frame content should start # a little bit more far to the right - assert (export_frames[str(first_frame_id + 1)]["objects"] - [first_exported_label_id]["bounding_box"]["left"] - > export_frames[str(first_frame_id)]["objects"] - [first_exported_label_id]["bounding_box"]["left"]) + assert ( + export_frames[str(first_frame_id + 1)]["objects"][ + first_exported_label_id + ]["bounding_box"]["left"] + > export_frames[str(first_frame_id)]["objects"][ + first_exported_label_id + ]["bounding_box"]["left"] + ) # But it shouldn't be 
further than the last frame - assert (export_frames[str(first_frame_id + 1)]["objects"] - [first_exported_label_id]["bounding_box"]["left"] - < export_frames[str(last_frame_id)]["objects"] - [first_exported_label_id]["bounding_box"]["left"]) + assert ( + export_frames[str(first_frame_id + 1)]["objects"][ + first_exported_label_id + ]["bounding_box"]["left"] + < export_frames[str(last_frame_id)]["objects"][ + first_exported_label_id + ]["bounding_box"]["left"] + ) # END OF THE VIDEO INTERPOLATION ASSERTIONS frame_with_nested_classifications = export_frames["13"] diff --git a/libs/labelbox/tests/data/metrics/confusion_matrix/conftest.py b/libs/labelbox/tests/data/metrics/confusion_matrix/conftest.py index ce82ff21d..c61c4f1df 100644 --- a/libs/labelbox/tests/data/metrics/confusion_matrix/conftest.py +++ b/libs/labelbox/tests/data/metrics/confusion_matrix/conftest.py @@ -2,30 +2,47 @@ import pytest -from labelbox.data.annotation_types import ClassificationAnnotation, ObjectAnnotation -from labelbox.data.annotation_types import Polygon, Point, Rectangle, Mask, MaskData, Line, Radio, Text, Checklist, ClassificationAnswer +from labelbox.data.annotation_types import ( + ClassificationAnnotation, + ObjectAnnotation, +) +from labelbox.data.annotation_types import ( + Polygon, + Point, + Rectangle, + Mask, + MaskData, + Line, + Radio, + Text, + Checklist, + ClassificationAnswer, +) import numpy as np from labelbox.data.annotation_types.ner import TextEntity class NameSpace(SimpleNamespace): - - def __init__(self, - predictions, - ground_truths, - expected, - expected_without_subclasses=None): + def __init__( + self, + predictions, + ground_truths, + expected, + expected_without_subclasses=None, + ): super(NameSpace, self).__init__( predictions=predictions, ground_truths=ground_truths, expected=expected, - expected_without_subclasses=expected_without_subclasses or expected) + expected_without_subclasses=expected_without_subclasses or expected, + ) def get_radio(name, answer_name): return ClassificationAnnotation( - name=name, value=Radio(answer=ClassificationAnswer(name=answer_name))) + name=name, value=Radio(answer=ClassificationAnswer(name=answer_name)) + ) def get_text(name, text_content): @@ -33,26 +50,33 @@ def get_text(name, text_content): def get_checklist(name, answer_names): - return ClassificationAnnotation(name=name, - value=Radio(answer=[ - ClassificationAnswer(name=answer_name) - for answer_name in answer_names - ])) + return ClassificationAnnotation( + name=name, + value=Radio( + answer=[ + ClassificationAnswer(name=answer_name) + for answer_name in answer_names + ] + ), + ) def get_polygon(name, points, subclasses=None): return ObjectAnnotation( name=name, value=Polygon(points=[Point(x=x, y=y) for x, y in points]), - classifications=[] if subclasses is None else subclasses) + classifications=[] if subclasses is None else subclasses, + ) def get_rectangle(name, start, end, subclasses=None): return ObjectAnnotation( name=name, - value=Rectangle(start=Point(x=start[0], y=start[1]), - end=Point(x=end[0], y=end[1])), - classifications=[] if subclasses is None else subclasses) + value=Rectangle( + start=Point(x=start[0], y=start[1]), end=Point(x=end[0], y=end[1]) + ), + classifications=[] if subclasses is None else subclasses, + ) def get_mask(name, pixels, color=(1, 1, 1), subclasses=None): @@ -62,272 +86,325 @@ def get_mask(name, pixels, color=(1, 1, 1), subclasses=None): return ObjectAnnotation( name=name, value=Mask(mask=MaskData(arr=mask), color=color), - classifications=[] if 
subclasses is None else subclasses) + classifications=[] if subclasses is None else subclasses, + ) def get_line(name, points, subclasses=None): return ObjectAnnotation( name=name, value=Line(points=[Point(x=x, y=y) for x, y in points]), - classifications=[] if subclasses is None else subclasses) + classifications=[] if subclasses is None else subclasses, + ) def get_point(name, x, y, subclasses=None): return ObjectAnnotation( name=name, value=Point(x=x, y=y), - classifications=[] if subclasses is None else subclasses) + classifications=[] if subclasses is None else subclasses, + ) def get_radio(name, answer_name): return ClassificationAnnotation( - name=name, value=Radio(answer=ClassificationAnswer(name=answer_name))) + name=name, value=Radio(answer=ClassificationAnswer(name=answer_name)) + ) def get_checklist(name, answer_names): - return ClassificationAnnotation(name=name, - value=Checklist(answer=[ - ClassificationAnswer(name=answer_name) - for answer_name in answer_names - ])) + return ClassificationAnnotation( + name=name, + value=Checklist( + answer=[ + ClassificationAnswer(name=answer_name) + for answer_name in answer_names + ] + ), + ) def get_ner(name, start, end, subclasses=None): return ObjectAnnotation( name=name, value=TextEntity(start=start, end=end), - classifications=[] if subclasses is None else subclasses) + classifications=[] if subclasses is None else subclasses, + ) def get_object_pairs(tool_fn, **kwargs): return [ - NameSpace(predictions=[tool_fn("cat", **kwargs)], - ground_truths=[tool_fn("cat", **kwargs)], - expected={'cat': [1, 0, 0, 0]}), + NameSpace( + predictions=[tool_fn("cat", **kwargs)], + ground_truths=[tool_fn("cat", **kwargs)], + expected={"cat": [1, 0, 0, 0]}, + ), NameSpace( predictions=[ - tool_fn("cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="yes")]) + tool_fn( + "cat", + **kwargs, + subclasses=[get_radio("is_animal", answer_name="yes")], + ) ], ground_truths=[ - tool_fn("cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="yes")]) + tool_fn( + "cat", + **kwargs, + subclasses=[get_radio("is_animal", answer_name="yes")], + ) ], - expected={'cat': [1, 0, 0, 0]}, - expected_without_subclasses={'cat': [1, 0, 0, 0]}), - NameSpace(predictions=[ - tool_fn("cat", + expected={"cat": [1, 0, 0, 0]}, + expected_without_subclasses={"cat": [1, 0, 0, 0]}, + ), + NameSpace( + predictions=[ + tool_fn( + "cat", **kwargs, - subclasses=[get_radio("is_animal", answer_name="yes")]) - ], - ground_truths=[ - tool_fn( - "cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - expected={'cat': [0, 1, 0, 1]}, - expected_without_subclasses={'cat': [1, 0, 0, 0]}), - NameSpace(predictions=[ - tool_fn("cat", + subclasses=[get_radio("is_animal", answer_name="yes")], + ) + ], + ground_truths=[ + tool_fn( + "cat", **kwargs, - subclasses=[get_radio("is_animal", answer_name="yes")]), - tool_fn("cat", + subclasses=[get_radio("is_animal", answer_name="no")], + ) + ], + expected={"cat": [0, 1, 0, 1]}, + expected_without_subclasses={"cat": [1, 0, 0, 0]}, + ), + NameSpace( + predictions=[ + tool_fn( + "cat", **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - ground_truths=[ - tool_fn( - "cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - expected={'cat': [1, 1, 0, 0]}, - expected_without_subclasses={'cat': [1, 1, 0, 0]}), - NameSpace(predictions=[ - tool_fn("cat", + subclasses=[get_radio("is_animal", answer_name="yes")], + ), + tool_fn( + "cat", **kwargs, - 
subclasses=[get_radio("is_animal", answer_name="yes")]), - tool_fn("dog", + subclasses=[get_radio("is_animal", answer_name="no")], + ), + ], + ground_truths=[ + tool_fn( + "cat", **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - ground_truths=[ - tool_fn( - "cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - expected={ - 'cat': [0, 1, 0, 1], - 'dog': [0, 1, 0, 0] - }, - expected_without_subclasses={ - 'cat': [1, 0, 0, 0], - 'dog': [0, 1, 0, 0] - }), + subclasses=[get_radio("is_animal", answer_name="no")], + ) + ], + expected={"cat": [1, 1, 0, 0]}, + expected_without_subclasses={"cat": [1, 1, 0, 0]}, + ), NameSpace( - predictions=[tool_fn("cat", **kwargs), - tool_fn("cat", **kwargs)], - ground_truths=[tool_fn("cat", **kwargs), - tool_fn("cat", **kwargs)], - expected={'cat': [2, 0, 0, 0]}), + predictions=[ + tool_fn( + "cat", + **kwargs, + subclasses=[get_radio("is_animal", answer_name="yes")], + ), + tool_fn( + "dog", + **kwargs, + subclasses=[get_radio("is_animal", answer_name="no")], + ), + ], + ground_truths=[ + tool_fn( + "cat", + **kwargs, + subclasses=[get_radio("is_animal", answer_name="no")], + ) + ], + expected={"cat": [0, 1, 0, 1], "dog": [0, 1, 0, 0]}, + expected_without_subclasses={ + "cat": [1, 0, 0, 0], + "dog": [0, 1, 0, 0], + }, + ), + NameSpace( + predictions=[tool_fn("cat", **kwargs), tool_fn("cat", **kwargs)], + ground_truths=[tool_fn("cat", **kwargs), tool_fn("cat", **kwargs)], + expected={"cat": [2, 0, 0, 0]}, + ), NameSpace( - predictions=[tool_fn("cat", **kwargs), - tool_fn("cat", **kwargs)], + predictions=[tool_fn("cat", **kwargs), tool_fn("cat", **kwargs)], ground_truths=[tool_fn("cat", **kwargs)], - expected={'cat': [1, 1, 0, 0]}), + expected={"cat": [1, 1, 0, 0]}, + ), + NameSpace( + predictions=[tool_fn("cat", **kwargs)], + ground_truths=[tool_fn("cat", **kwargs), tool_fn("cat", **kwargs)], + expected={"cat": [1, 0, 0, 1]}, + ), + NameSpace( + predictions=[], + ground_truths=[], + expected=[], + expected_without_subclasses=[], + ), + NameSpace( + predictions=[], + ground_truths=[tool_fn("cat", **kwargs)], + expected={"cat": [0, 0, 0, 1]}, + ), + NameSpace( + predictions=[tool_fn("cat", **kwargs)], + ground_truths=[], + expected={"cat": [0, 1, 0, 0]}, + ), NameSpace( predictions=[tool_fn("cat", **kwargs)], - ground_truths=[tool_fn("cat", **kwargs), - tool_fn("cat", **kwargs)], - expected={'cat': [1, 0, 0, 1]}), - NameSpace(predictions=[], - ground_truths=[], - expected=[], - expected_without_subclasses=[]), - NameSpace(predictions=[], - ground_truths=[tool_fn("cat", **kwargs)], - expected={'cat': [0, 0, 0, 1]}), - NameSpace(predictions=[tool_fn("cat", **kwargs)], - ground_truths=[], - expected={'cat': [0, 1, 0, 0]}), - NameSpace(predictions=[tool_fn("cat", **kwargs)], - ground_truths=[tool_fn("dog", **kwargs)], - expected={ - 'cat': [0, 1, 0, 0], - 'dog': [0, 0, 0, 1] - }) + ground_truths=[tool_fn("dog", **kwargs)], + expected={"cat": [0, 1, 0, 0], "dog": [0, 0, 0, 1]}, + ), ] @pytest.fixture def radio_pairs(): return [ - NameSpace(predictions=[get_radio("is_animal", answer_name="yes")], - ground_truths=[get_radio("is_animal", answer_name="yes")], - expected={'yes': [1, 0, 0, 0]}), - NameSpace(predictions=[get_radio("is_animal", answer_name="yes")], - ground_truths=[get_radio("is_animal", answer_name="no")], - expected={ - 'no': [0, 0, 0, 1], - 'yes': [0, 1, 0, 0] - }), - NameSpace(predictions=[get_radio("is_animal", answer_name="yes")], - ground_truths=[], - expected={'yes': [0, 1, 0, 0]}), - 
NameSpace(predictions=[], - ground_truths=[get_radio("is_animal", answer_name="yes")], - expected={'yes': [0, 0, 0, 1]}), - NameSpace(predictions=[ - get_radio("is_animal", answer_name="yes"), - get_radio("is_short", answer_name="no") - ], - ground_truths=[get_radio("is_animal", answer_name="yes")], - expected={ - 'no': [0, 1, 0, 0], - 'yes': [1, 0, 0, 0] - }), - #Not supported yet: + NameSpace( + predictions=[get_radio("is_animal", answer_name="yes")], + ground_truths=[get_radio("is_animal", answer_name="yes")], + expected={"yes": [1, 0, 0, 0]}, + ), + NameSpace( + predictions=[get_radio("is_animal", answer_name="yes")], + ground_truths=[get_radio("is_animal", answer_name="no")], + expected={"no": [0, 0, 0, 1], "yes": [0, 1, 0, 0]}, + ), + NameSpace( + predictions=[get_radio("is_animal", answer_name="yes")], + ground_truths=[], + expected={"yes": [0, 1, 0, 0]}, + ), + NameSpace( + predictions=[], + ground_truths=[get_radio("is_animal", answer_name="yes")], + expected={"yes": [0, 0, 0, 1]}, + ), + NameSpace( + predictions=[ + get_radio("is_animal", answer_name="yes"), + get_radio("is_short", answer_name="no"), + ], + ground_truths=[get_radio("is_animal", answer_name="yes")], + expected={"no": [0, 1, 0, 0], "yes": [1, 0, 0, 0]}, + ), + # Not supported yet: # NameSpace( - #predictions=[], - #ground_truths=[], - #expected = [0,0,1,0] - #) + # predictions=[], + # ground_truths=[], + # expected = [0,0,1,0] + # ) ] @pytest.fixture def checklist_pairs(): return [ - NameSpace(predictions=[ - get_checklist("animal_attributes", answer_names=["striped"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped"]) - ], - expected={'striped': [1, 0, 0, 0]}), - NameSpace(predictions=[ - get_checklist("animal_attributes", answer_names=["striped"]) - ], - ground_truths=[], - expected={'striped': [0, 1, 0, 0]}), - NameSpace(predictions=[], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped"]) - ], - expected={'striped': [0, 0, 0, 1]}), - NameSpace(predictions=[ - get_checklist("animal_attributes", - answer_names=["striped", "short"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped"]) - ], - expected={ - 'short': [0, 1, 0, 0], - 'striped': [1, 0, 0, 0] - }), - NameSpace(predictions=[ - get_checklist("animal_attributes", answer_names=["striped"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped", "short"]) - ], - expected={ - 'short': [0, 0, 0, 1], - 'striped': [1, 0, 0, 0] - }), - NameSpace(predictions=[ - get_checklist("animal_attributes", - answer_names=["striped", "short", "black"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped", "short"]) - ], - expected={ - 'black': [0, 1, 0, 0], - 'short': [1, 0, 0, 0], - 'striped': [1, 0, 0, 0] - }), - NameSpace(predictions=[ - get_checklist("animal_attributes", - answer_names=["striped", "short", "black"]), - get_checklist("animal_name", answer_names=["doggy", "pup"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped", "short"]), - get_checklist("animal_name", answer_names=["pup"]) - ], - expected={ - 'black': [0, 1, 0, 0], - 'doggy': [0, 1, 0, 0], - 'pup': [1, 0, 0, 0], - 'short': [1, 0, 0, 0], - 'striped': [1, 0, 0, 0] - }) - - #Not supported yet: + NameSpace( + predictions=[ + get_checklist("animal_attributes", answer_names=["striped"]) + ], + ground_truths=[ + get_checklist("animal_attributes", answer_names=["striped"]) + ], + expected={"striped": [1, 0, 0, 0]}, 
+ ), + NameSpace( + predictions=[ + get_checklist("animal_attributes", answer_names=["striped"]) + ], + ground_truths=[], + expected={"striped": [0, 1, 0, 0]}, + ), + NameSpace( + predictions=[], + ground_truths=[ + get_checklist("animal_attributes", answer_names=["striped"]) + ], + expected={"striped": [0, 0, 0, 1]}, + ), + NameSpace( + predictions=[ + get_checklist( + "animal_attributes", answer_names=["striped", "short"] + ) + ], + ground_truths=[ + get_checklist("animal_attributes", answer_names=["striped"]) + ], + expected={"short": [0, 1, 0, 0], "striped": [1, 0, 0, 0]}, + ), + NameSpace( + predictions=[ + get_checklist("animal_attributes", answer_names=["striped"]) + ], + ground_truths=[ + get_checklist( + "animal_attributes", answer_names=["striped", "short"] + ) + ], + expected={"short": [0, 0, 0, 1], "striped": [1, 0, 0, 0]}, + ), + NameSpace( + predictions=[ + get_checklist( + "animal_attributes", + answer_names=["striped", "short", "black"], + ) + ], + ground_truths=[ + get_checklist( + "animal_attributes", answer_names=["striped", "short"] + ) + ], + expected={ + "black": [0, 1, 0, 0], + "short": [1, 0, 0, 0], + "striped": [1, 0, 0, 0], + }, + ), + NameSpace( + predictions=[ + get_checklist( + "animal_attributes", + answer_names=["striped", "short", "black"], + ), + get_checklist("animal_name", answer_names=["doggy", "pup"]), + ], + ground_truths=[ + get_checklist( + "animal_attributes", answer_names=["striped", "short"] + ), + get_checklist("animal_name", answer_names=["pup"]), + ], + expected={ + "black": [0, 1, 0, 0], + "doggy": [0, 1, 0, 0], + "pup": [1, 0, 0, 0], + "short": [1, 0, 0, 0], + "striped": [1, 0, 0, 0], + }, + ), + # Not supported yet: # NameSpace( - #predictions=[], - #ground_truths=[], - #expected = [0,0,1,0] - #) + # predictions=[], + # ground_truths=[], + # expected = [0,0,1,0] + # ) ] @pytest.fixture def polygon_pairs(): - return get_object_pairs(get_polygon, - points=[[0, 0], [10, 0], [10, 10], [0, 10]]) + return get_object_pairs( + get_polygon, points=[[0, 0], [10, 0], [10, 10], [0, 10]] + ) @pytest.fixture @@ -342,8 +419,9 @@ def mask_pairs(): @pytest.fixture def line_pairs(): - return get_object_pairs(get_line, - points=[[0, 0], [10, 0], [10, 10], [0, 10]]) + return get_object_pairs( + get_line, points=[[0, 0], [10, 0], [10, 10], [0, 10]] + ) @pytest.fixture @@ -359,47 +437,39 @@ def ner_pairs(): @pytest.fixture() def pair_iou_thresholds(): return [ - NameSpace(predictions=[ - get_polygon("cat", points=[[0, 0], [10, 0], [10, 10], [0, 10]]), - ], - ground_truths=[ - get_polygon("cat", - points=[[0, 0], [5, 0], [5, 5], [0, 5]]), - ], - expected={ - 0.2: [1, 0, 0, 0], - 0.3: [0, 1, 0, 1] - }), + NameSpace( + predictions=[ + get_polygon("cat", points=[[0, 0], [10, 0], [10, 10], [0, 10]]), + ], + ground_truths=[ + get_polygon("cat", points=[[0, 0], [5, 0], [5, 5], [0, 5]]), + ], + expected={0.2: [1, 0, 0, 0], 0.3: [0, 1, 0, 1]}, + ), NameSpace( predictions=[get_rectangle("cat", start=[0, 0], end=[10, 10])], ground_truths=[get_rectangle("cat", start=[0, 0], end=[5, 5])], - expected={ - 0.2: [1, 0, 0, 0], - 0.3: [0, 1, 0, 1] - }), - NameSpace(predictions=[get_point("cat", x=0, y=0)], - ground_truths=[get_point("cat", x=20, y=20)], - expected={ - 0.5: [1, 0, 0, 0], - 0.65: [0, 1, 0, 1] - }), - NameSpace(predictions=[ - get_line("cat", points=[[0, 0], [10, 0], [10, 10], [0, 10]]) - ], - ground_truths=[ - get_line("cat", - points=[[0, 0], [100, 0], [100, 100], [0, 100]]) - ], - expected={ - 0.3: [1, 0, 0, 0], - 0.65: [0, 1, 0, 1] - }), - NameSpace(predictions=[ 
- get_mask("cat", pixels=[[0, 0], [1, 1], [2, 2], [3, 3]]) - ], - ground_truths=[get_mask("cat", pixels=[[0, 0], [1, 1]])], - expected={ - 0.4: [1, 0, 0, 0], - 0.6: [0, 1, 0, 1] - }), + expected={0.2: [1, 0, 0, 0], 0.3: [0, 1, 0, 1]}, + ), + NameSpace( + predictions=[get_point("cat", x=0, y=0)], + ground_truths=[get_point("cat", x=20, y=20)], + expected={0.5: [1, 0, 0, 0], 0.65: [0, 1, 0, 1]}, + ), + NameSpace( + predictions=[ + get_line("cat", points=[[0, 0], [10, 0], [10, 10], [0, 10]]) + ], + ground_truths=[ + get_line("cat", points=[[0, 0], [100, 0], [100, 100], [0, 100]]) + ], + expected={0.3: [1, 0, 0, 0], 0.65: [0, 1, 0, 1]}, + ), + NameSpace( + predictions=[ + get_mask("cat", pixels=[[0, 0], [1, 1], [2, 2], [3, 3]]) + ], + ground_truths=[get_mask("cat", pixels=[[0, 0], [1, 1]])], + expected={0.4: [1, 0, 0, 0], 0.6: [0, 1, 0, 1]}, + ), ] diff --git a/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_data_row.py b/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_data_row.py index e84207ac2..e3ac86213 100644 --- a/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_data_row.py +++ b/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_data_row.py @@ -1,47 +1,57 @@ from pytest_cases import fixture_ref from pytest_cases import parametrize, fixture_ref -from labelbox.data.metrics.confusion_matrix.confusion_matrix import confusion_matrix_metric +from labelbox.data.metrics.confusion_matrix.confusion_matrix import ( + confusion_matrix_metric, +) -@parametrize("tool_examples", [ - fixture_ref('polygon_pairs'), - fixture_ref('rectangle_pairs'), - fixture_ref('mask_pairs'), - fixture_ref('line_pairs'), - fixture_ref('point_pairs'), - fixture_ref('ner_pairs') -]) +@parametrize( + "tool_examples", + [ + fixture_ref("polygon_pairs"), + fixture_ref("rectangle_pairs"), + fixture_ref("mask_pairs"), + fixture_ref("line_pairs"), + fixture_ref("point_pairs"), + fixture_ref("ner_pairs"), + ], +) def test_overlapping_objects(tool_examples): for example in tool_examples: - - for include_subclasses, expected_attr_name in [[ - True, 'expected' - ], [False, 'expected_without_subclasses']]: + for include_subclasses, expected_attr_name in [ + [True, "expected"], + [False, "expected_without_subclasses"], + ]: score = confusion_matrix_metric( example.ground_truths, example.predictions, - include_subclasses=include_subclasses) + include_subclasses=include_subclasses, + ) if len(getattr(example, expected_attr_name)) == 0: assert len(score) == 0 else: expected = [0, 0, 0, 0] - for expected_values in getattr(example, - expected_attr_name).values(): + for expected_values in getattr( + example, expected_attr_name + ).values(): for idx in range(4): expected[idx] += expected_values[idx] assert score[0].value == tuple( - expected), f"{example.predictions},{example.ground_truths}" + expected + ), f"{example.predictions},{example.ground_truths}" -@parametrize("tool_examples", - [fixture_ref('checklist_pairs'), - fixture_ref('radio_pairs')]) +@parametrize( + "tool_examples", + [fixture_ref("checklist_pairs"), fixture_ref("radio_pairs")], +) def test_overlapping_classifications(tool_examples): for example in tool_examples: - score = confusion_matrix_metric(example.ground_truths, - example.predictions) + score = confusion_matrix_metric( + example.ground_truths, example.predictions + ) if len(example.expected) == 0: assert len(score) == 0 else: @@ -50,15 +60,16 @@ def test_overlapping_classifications(tool_examples): for idx in range(4): 
expected[idx] += expected_values[idx] assert score[0].value == tuple( - expected), f"{example.predictions},{example.ground_truths}" + expected + ), f"{example.predictions},{example.ground_truths}" def test_partial_overlap(pair_iou_thresholds): for example in pair_iou_thresholds: for iou in example.expected.keys(): - score = confusion_matrix_metric(example.predictions, - example.ground_truths, - iou=iou) + score = confusion_matrix_metric( + example.predictions, example.ground_truths, iou=iou + ) assert score[0].value == tuple( example.expected[iou] ), f"{example.predictions},{example.ground_truths}" diff --git a/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_feature.py b/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_feature.py index f55555e75..818c01f72 100644 --- a/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_feature.py +++ b/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_feature.py @@ -1,26 +1,33 @@ from pytest_cases import fixture_ref from pytest_cases import parametrize, fixture_ref -from labelbox.data.metrics.confusion_matrix.confusion_matrix import feature_confusion_matrix_metric - - -@parametrize("tool_examples", [ - fixture_ref('polygon_pairs'), - fixture_ref('rectangle_pairs'), - fixture_ref('mask_pairs'), - fixture_ref('line_pairs'), - fixture_ref('point_pairs'), - fixture_ref('ner_pairs') -]) +from labelbox.data.metrics.confusion_matrix.confusion_matrix import ( + feature_confusion_matrix_metric, +) + + +@parametrize( + "tool_examples", + [ + fixture_ref("polygon_pairs"), + fixture_ref("rectangle_pairs"), + fixture_ref("mask_pairs"), + fixture_ref("line_pairs"), + fixture_ref("point_pairs"), + fixture_ref("ner_pairs"), + ], +) def test_overlapping_objects(tool_examples): for example in tool_examples: - for include_subclasses, expected_attr_name in [[ - True, 'expected' - ], [False, 'expected_without_subclasses']]: + for include_subclasses, expected_attr_name in [ + [True, "expected"], + [False, "expected_without_subclasses"], + ]: metrics = feature_confusion_matrix_metric( example.ground_truths, example.predictions, - include_subclasses=include_subclasses) + include_subclasses=include_subclasses, + ) metrics = {r.feature_name: list(r.value) for r in metrics} if len(getattr(example, expected_attr_name)) == 0: @@ -31,17 +38,20 @@ def test_overlapping_objects(tool_examples): ), f"{example.predictions},{example.ground_truths}" -@parametrize("tool_examples", - [fixture_ref('checklist_pairs'), - fixture_ref('radio_pairs')]) +@parametrize( + "tool_examples", + [fixture_ref("checklist_pairs"), fixture_ref("radio_pairs")], +) def test_overlapping_classifications(tool_examples): for example in tool_examples: - - metrics = feature_confusion_matrix_metric(example.ground_truths, - example.predictions) + metrics = feature_confusion_matrix_metric( + example.ground_truths, example.predictions + ) metrics = {r.feature_name: list(r.value) for r in metrics} if len(example.expected) == 0: assert len(metrics) == 0 else: - assert metrics == example.expected, f"{example.predictions},{example.ground_truths}" + assert ( + metrics == example.expected + ), f"{example.predictions},{example.ground_truths}" diff --git a/libs/labelbox/tests/data/metrics/iou/data_row/conftest.py b/libs/labelbox/tests/data/metrics/iou/data_row/conftest.py index d25abe2cf..6614cecf4 100644 --- a/libs/labelbox/tests/data/metrics/iou/data_row/conftest.py +++ b/libs/labelbox/tests/data/metrics/iou/data_row/conftest.py @@ -7,780 
+7,696 @@ class NameSpace(SimpleNamespace): - - def __init__(self, - predictions, - labels, - expected, - expected_without_subclasses=None, - data_row_expected=None, - media_attributes=None, - metadata=None, - classifications=None): + def __init__( + self, + predictions, + labels, + expected, + expected_without_subclasses=None, + data_row_expected=None, + media_attributes=None, + metadata=None, + classifications=None, + ): super(NameSpace, self).__init__( predictions=predictions, labels={ - 'DataRow ID': 'ckppihxc10005aeyjen11h7jh', - 'Labeled Data': "https://.jpg", - 'Media Attributes': media_attributes or {}, - 'DataRow Metadata': metadata or [], - 'Label': { - 'objects': labels, - 'classifications': classifications or [] - } + "DataRow ID": "ckppihxc10005aeyjen11h7jh", + "Labeled Data": "https://.jpg", + "Media Attributes": media_attributes or {}, + "DataRow Metadata": metadata or [], + "Label": { + "objects": labels, + "classifications": classifications or [], + }, }, expected=expected, expected_without_subclasses=expected_without_subclasses or expected, - data_row_expected=data_row_expected) + data_row_expected=data_row_expected, + ) @pytest.fixture def polygon_pair(): - return NameSpace(labels=[{ - 'featureId': - 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 1 - }, { - 'x': 0, - 'y': 1 - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 0.5 - }, { - 'x': 0, - 'y': 0.5 - }] - }], - expected=0.5) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "polygon": [ + {"x": 0, "y": 0}, + {"x": 1, "y": 0}, + {"x": 1, "y": 1}, + {"x": 0, "y": 1}, + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "polygon": [ + {"x": 0, "y": 0}, + {"x": 1, "y": 0}, + {"x": 1, "y": 0.5}, + {"x": 0, "y": 0.5}, + ], + } + ], + expected=0.5, + ) @pytest.fixture def box_pair(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - } - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - } - }], - expected=1.0) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "bbox": { + "top": 1099, + "left": 2010, + "height": 690, + "width": 591, + }, + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "bbox": { + "top": 1099, + "left": 2010, + "height": 690, + "width": 591, + }, + } + ], + expected=1.0, + ) @pytest.fixture def unmatched_prediction(): - return NameSpace(labels=[{ - 'featureId': - 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 1 - }, { - 'x': 0, - 'y': 1 - }] - }], - 
predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 0.5 - }, { - 'x': 0, - 'y': 0.5 - }] - }, { - 'uuid': - 'd0ba2520-02e9-47d4-8736-088bbdbabbc3', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'polygon': [{ - 'x': 10, - 'y': 10 - }, { - 'x': 11, - 'y': 10 - }, { - 'x': 11, - 'y': 1.5 - }, { - 'x': 10, - 'y': 1.5 - }] - }], - expected=0.25) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "polygon": [ + {"x": 0, "y": 0}, + {"x": 1, "y": 0}, + {"x": 1, "y": 1}, + {"x": 0, "y": 1}, + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "polygon": [ + {"x": 0, "y": 0}, + {"x": 1, "y": 0}, + {"x": 1, "y": 0.5}, + {"x": 0, "y": 0.5}, + ], + }, + { + "uuid": "d0ba2520-02e9-47d4-8736-088bbdbabbc3", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "polygon": [ + {"x": 10, "y": 10}, + {"x": 11, "y": 10}, + {"x": 11, "y": 1.5}, + {"x": 10, "y": 1.5}, + ], + }, + ], + expected=0.25, + ) @pytest.fixture def unmatched_label(): - return NameSpace(labels=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 1 - }, { - 'x': 0, - 'y': 1 - }] - }, { - 'featureId': - 'ckppiw3bs0007aeyjs3pvrqzi', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'polygon': [{ - 'x': 10, - 'y': 10 - }, { - 'x': 11, - 'y': 10 - }, { - 'x': 11, - 'y': 11 - }, { - 'x': 10, - 'y': 11 - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 0.5 - }, { - 'x': 0, - 'y': 0.5 - }] - }], - expected=0.25) + return NameSpace( + labels=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "polygon": [ + {"x": 0, "y": 0}, + {"x": 1, "y": 0}, + {"x": 1, "y": 1}, + {"x": 0, "y": 1}, + ], + }, + { + "featureId": "ckppiw3bs0007aeyjs3pvrqzi", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "polygon": [ + {"x": 10, "y": 10}, + {"x": 11, "y": 10}, + {"x": 11, "y": 11}, + {"x": 10, "y": 11}, + ], + }, + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "polygon": [ + {"x": 0, "y": 0}, + {"x": 1, "y": 0}, + {"x": 1, "y": 0.5}, + {"x": 0, "y": 0.5}, + ], + } + ], + expected=0.25, + ) def create_mask_url(indices, h, w, value): mask = np.zeros((h, w, 3), dtype=np.uint8) for idx in indices: mask[idx] = value - return base64.b64encode(mask.tobytes()).decode('utf-8') + return base64.b64encode(mask.tobytes()).decode("utf-8") @pytest.fixture def mask_pair(): - return NameSpace(labels=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'instanceURI': - create_mask_url([(0, 0), (0, 1)], 32, 32, (255, 255, 255)) - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 
'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'mask': { - 'instanceURI': - create_mask_url([(0, 0)], 32, 32, (1, 1, 1)), - 'colorRGB': (1, 1, 1) - } - }], - expected=0.5) + return NameSpace( + labels=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "instanceURI": create_mask_url( + [(0, 0), (0, 1)], 32, 32, (255, 255, 255) + ), + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "mask": { + "instanceURI": create_mask_url([(0, 0)], 32, 32, (1, 1, 1)), + "colorRGB": (1, 1, 1), + }, + } + ], + expected=0.5, + ) @pytest.fixture def matching_radio(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckrm02no8000008l3arwp6h4f', - 'answer': { - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - } - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckrm02no8000008l3arwp6h4f', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answer': { - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - } - }], - expected=1.) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckrm02no8000008l3arwp6h4f", + "answer": {"schemaId": "ckppid25v0000aeyjmxfwlc7t"}, + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckrm02no8000008l3arwp6h4f", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answer": {"schemaId": "ckppid25v0000aeyjmxfwlc7t"}, + } + ], + expected=1.0, + ) @pytest.fixture def empty_radio_label(): - return NameSpace(labels=[], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answer': { - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - } - }], - expected=0) + return NameSpace( + labels=[], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answer": {"schemaId": "ckppid25v0000aeyjmxfwlc7t"}, + } + ], + expected=0, + ) @pytest.fixture def empty_radio_prediction(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': { - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - } - }], - predictions=[], - expected=0) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answer": {"schemaId": "ckppid25v0000aeyjmxfwlc7t"}, + } + ], + predictions=[], + expected=0, + ) @pytest.fixture def matching_checklist(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }] - }], - data_row_expected=1., - 
expected={1.0: 3}) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + {"schemaId": "ckppidq4u0002aeyjmcc4toxw"}, + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + {"schemaId": "ckppidq4u0002aeyjmcc4toxw"}, + ], + } + ], + data_row_expected=1.0, + expected={1.0: 3}, + ) @pytest.fixture def partially_matching_checklist_1(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }, { - 'schemaId': 'ckppie29m0003aeyjk1ixzcom' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }, { - 'schemaId': 'ckppiebx80004aeyjuwvos69e' - }] - }], - data_row_expected=0.6, - expected={ - 0.0: 2, - 1.0: 3 - }) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + {"schemaId": "ckppidq4u0002aeyjmcc4toxw"}, + {"schemaId": "ckppie29m0003aeyjk1ixzcom"}, + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + {"schemaId": "ckppidq4u0002aeyjmcc4toxw"}, + {"schemaId": "ckppiebx80004aeyjuwvos69e"}, + ], + } + ], + data_row_expected=0.6, + expected={0.0: 2, 1.0: 3}, + ) @pytest.fixture def partially_matching_checklist_2(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }, { - 'schemaId': 'ckppiebx80004aeyjuwvos69e' - }] - }], - data_row_expected=0.5, - expected={ - 1.0: 2, - 0.0: 2 - }) + return NameSpace( + labels=[], 
+ classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + {"schemaId": "ckppidq4u0002aeyjmcc4toxw"}, + {"schemaId": "ckppiebx80004aeyjuwvos69e"}, + ], + } + ], + data_row_expected=0.5, + expected={1.0: 2, 0.0: 2}, + ) @pytest.fixture def partially_matching_checklist_3(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }, { - 'schemaId': 'ckppiebx80004aeyjuwvos69e' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }] - }], - data_row_expected=0.5, - expected={ - 1.0: 2, - 0.0: 2 - }) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + {"schemaId": "ckppidq4u0002aeyjmcc4toxw"}, + {"schemaId": "ckppiebx80004aeyjuwvos69e"}, + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + ], + } + ], + data_row_expected=0.5, + expected={1.0: 2, 0.0: 2}, + ) @pytest.fixture def empty_checklist_label(): - return NameSpace(labels=[], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - }] - }], - data_row_expected=0.0, - expected={0.0: 1}) + return NameSpace( + labels=[], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answers": [{"schemaId": "ckppid25v0000aeyjmxfwlc7t"}], + } + ], + data_row_expected=0.0, + expected={0.0: 1}, + ) @pytest.fixture def empty_checklist_prediction(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - }] - }], - predictions=[], - data_row_expected=0.0, - expected={0.0: 1}) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": 
"1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answers": [{"schemaId": "ckppid25v0000aeyjmxfwlc7t"}], + } + ], + predictions=[], + data_row_expected=0.0, + expected={0.0: 1}, + ) @pytest.fixture def matching_text(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'test' - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answer': 'test' - }], - expected=1.0) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answer": "test", + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answer": "test", + } + ], + expected=1.0, + ) @pytest.fixture def not_matching_text(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'test' - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answer': 'not_test' - }], - expected=0.) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answer": "test", + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answer": "not_test", + } + ], + expected=0.0, + ) @pytest.fixture def test_box_with_subclass(): - return NameSpace(labels=[{ - 'featureId': - 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - }, - 'classifications': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'test' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - }, - 'classifications': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'test' - }] - }], - expected=1.0) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "bbox": { + "top": 1099, + "left": 2010, + "height": 690, + "width": 591, + }, + "classifications": [ + {"schemaId": "ckppid25v0000aeyjmxfwlc7t", "answer": "test"} + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "bbox": { + "top": 1099, + "left": 2010, + "height": 690, + "width": 591, + }, + "classifications": [ + {"schemaId": "ckppid25v0000aeyjmxfwlc7t", "answer": "test"} + ], + } + ], + expected=1.0, + ) @pytest.fixture def test_box_with_wrong_subclass(): - return NameSpace(labels=[{ - 'featureId': - 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - }, - 'classifications': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 
'test' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - }, - 'classifications': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'not_test' - }] - }], - expected=0.5, - expected_without_subclasses=1.0) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "bbox": { + "top": 1099, + "left": 2010, + "height": 690, + "width": 591, + }, + "classifications": [ + {"schemaId": "ckppid25v0000aeyjmxfwlc7t", "answer": "test"} + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "bbox": { + "top": 1099, + "left": 2010, + "height": 690, + "width": 591, + }, + "classifications": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answer": "not_test", + } + ], + } + ], + expected=0.5, + expected_without_subclasses=1.0, + ) @pytest.fixture def line_pair(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "line": [{ - "x": 0, - "y": 100 - }, { - "x": 0, - "y": 0 - }], - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - "line": [{ - "x": 5, - "y": 95 - }, { - "x": 0, - "y": 0 - }], - }], - expected=0.9496975567603978) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "line": [{"x": 0, "y": 100}, {"x": 0, "y": 0}], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "line": [{"x": 5, "y": 95}, {"x": 0, "y": 0}], + } + ], + expected=0.9496975567603978, + ) @pytest.fixture def point_pair(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "point": { - 'x': 0, - 'y': 0 - } - }], - predictions=[{ - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "point": { - 'x': 5, - 'y': 5 - } - }], - expected=0.879113232477017) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "point": {"x": 0, "y": 0}, + } + ], + predictions=[ + { + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "point": {"x": 5, "y": 5}, + } + ], + expected=0.879113232477017, + ) @pytest.fixture def matching_ner(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'format': "text.location", - 'data': { - "location": { - "start": 0, - "end": 10 - } - } - }], - predictions=[{ - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "location": { - "start": 0, - "end": 10 - } - }], - expected=1) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "format": "text.location", + "data": {"location": {"start": 0, 
"end": 10}}, + } + ], + predictions=[ + { + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "location": {"start": 0, "end": 10}, + } + ], + expected=1, + ) @pytest.fixture def no_matching_ner(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'format': "text.location", - 'data': { - "location": { - "start": 0, - "end": 5 - } - } - }], - predictions=[{ - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "location": { - "start": 5, - "end": 10 - } - }], - expected=0) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "format": "text.location", + "data": {"location": {"start": 0, "end": 5}}, + } + ], + predictions=[ + { + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "location": {"start": 5, "end": 10}, + } + ], + expected=0, + ) @pytest.fixture def partial_matching_ner(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'format': "text.location", - 'data': { - "location": { - "start": 0, - "end": 7 - } - } - }], - predictions=[{ - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "location": { - "start": 3, - "end": 5 - } - }], - expected=0.2857142857142857) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "format": "text.location", + "data": {"location": {"start": 0, "end": 7}}, + } + ], + predictions=[ + { + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "location": {"start": 3, "end": 5}, + } + ], + expected=0.2857142857142857, + ) diff --git a/libs/labelbox/tests/data/metrics/iou/feature/conftest.py b/libs/labelbox/tests/data/metrics/iou/feature/conftest.py index c89d30056..c3b2a28e3 100644 --- a/libs/labelbox/tests/data/metrics/iou/feature/conftest.py +++ b/libs/labelbox/tests/data/metrics/iou/feature/conftest.py @@ -2,107 +2,140 @@ import pytest -from labelbox.data.annotation_types import ClassificationAnnotation, ObjectAnnotation +from labelbox.data.annotation_types import ( + ClassificationAnnotation, + ObjectAnnotation, +) from labelbox.data.annotation_types import Polygon, Point class NameSpace(SimpleNamespace): - def __init__(self, predictions, ground_truths, expected): - super(NameSpace, self).__init__(predictions=predictions, - ground_truths=ground_truths, - expected=expected) + super(NameSpace, self).__init__( + predictions=predictions, + ground_truths=ground_truths, + expected=expected, + ) @pytest.fixture def different_classes(): return [ - NameSpace(predictions=[ - ObjectAnnotation(name="cat", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - ground_truths=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - expected={ - 'cat': 0, - 'dog': 0 - }) + NameSpace( + predictions=[ + ObjectAnnotation( + name="cat", + value=Polygon( + points=[ + Point(x=0, y=0), + 
Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ) + ], + ground_truths=[ + ObjectAnnotation( + name="dog", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ) + ], + expected={"cat": 0, "dog": 0}, + ) ] @pytest.fixture def one_overlap_class(): return [ - NameSpace(predictions=[ - ObjectAnnotation(name="cat", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])), - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=5, y=0), - Point(x=5, y=5), - Point(x=0, y=5) - ])) - ], - ground_truths=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - expected={ - 'dog': 0.25, - 'cat': 0. - }), - NameSpace(predictions=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=5, y=0), - Point(x=5, y=5), - Point(x=0, y=5) - ])) - ], - ground_truths=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])), - ObjectAnnotation(name="cat", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - expected={ - 'dog': 0.25, - 'cat': 0. - }) + NameSpace( + predictions=[ + ObjectAnnotation( + name="cat", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ), + ObjectAnnotation( + name="dog", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=5, y=0), + Point(x=5, y=5), + Point(x=0, y=5), + ] + ), + ), + ], + ground_truths=[ + ObjectAnnotation( + name="dog", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ) + ], + expected={"dog": 0.25, "cat": 0.0}, + ), + NameSpace( + predictions=[ + ObjectAnnotation( + name="dog", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=5, y=0), + Point(x=5, y=5), + Point(x=0, y=5), + ] + ), + ) + ], + ground_truths=[ + ObjectAnnotation( + name="dog", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ), + ObjectAnnotation( + name="cat", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ), + ], + expected={"dog": 0.25, "cat": 0.0}, + ), ] @@ -110,46 +143,60 @@ def one_overlap_class(): def empty_annotations(): return [ NameSpace(predictions=[], ground_truths=[], expected={}), - NameSpace(predictions=[], - ground_truths=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])), - ObjectAnnotation(name="cat", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - expected={ - 'dog': 0., - 'cat': 0. - }), - NameSpace(predictions=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])), - ObjectAnnotation(name="cat", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - ground_truths=[], - expected={ - 'dog': 0., - 'cat': 0. 
- }) + NameSpace( + predictions=[], + ground_truths=[ + ObjectAnnotation( + name="dog", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ), + ObjectAnnotation( + name="cat", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ), + ], + expected={"dog": 0.0, "cat": 0.0}, + ), + NameSpace( + predictions=[ + ObjectAnnotation( + name="dog", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ), + ObjectAnnotation( + name="cat", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ), + ], + ground_truths=[], + expected={"dog": 0.0, "cat": 0.0}, + ), ] diff --git a/libs/labelbox/tests/data/metrics/iou/feature/test_feature_iou.py b/libs/labelbox/tests/data/metrics/iou/feature/test_feature_iou.py index 653e485d1..da324c51b 100644 --- a/libs/labelbox/tests/data/metrics/iou/feature/test_feature_iou.py +++ b/libs/labelbox/tests/data/metrics/iou/feature/test_feature_iou.py @@ -19,7 +19,8 @@ def check_iou(pair): assert len(one_metrics) one_metric = one_metrics[0] assert one_metric.value == sum(list(pair.expected.values())) / len( - pair.expected) + pair.expected + ) def test_different_classes(different_classes): diff --git a/libs/labelbox/tests/data/serialization/coco/test_coco.py b/libs/labelbox/tests/data/serialization/coco/test_coco.py index 0113b555d..a7c733ce5 100644 --- a/libs/labelbox/tests/data/serialization/coco/test_coco.py +++ b/libs/labelbox/tests/data/serialization/coco/test_coco.py @@ -7,9 +7,10 @@ def run_instances(tmpdir): - instance_json = json.load(open(Path(COCO_ASSETS_DIR, 'instances.json'))) - res = COCOConverter.deserialize_instances(instance_json, - Path(COCO_ASSETS_DIR, 'images')) + instance_json = json.load(open(Path(COCO_ASSETS_DIR, "instances.json"))) + res = COCOConverter.deserialize_instances( + instance_json, Path(COCO_ASSETS_DIR, "images") + ) back = COCOConverter.serialize_instances( res, Path(tmpdir), @@ -17,18 +18,21 @@ def run_instances(tmpdir): def test_rle_objects(tmpdir): - rle_json = json.load(open(Path(COCO_ASSETS_DIR, 'rle.json'))) - res = COCOConverter.deserialize_instances(rle_json, - Path(COCO_ASSETS_DIR, 'images')) + rle_json = json.load(open(Path(COCO_ASSETS_DIR, "rle.json"))) + res = COCOConverter.deserialize_instances( + rle_json, Path(COCO_ASSETS_DIR, "images") + ) back = COCOConverter.serialize_instances(res, tmpdir) def test_panoptic(tmpdir): - panoptic_json = json.load(open(Path(COCO_ASSETS_DIR, 'panoptic.json'))) + panoptic_json = json.load(open(Path(COCO_ASSETS_DIR, "panoptic.json"))) image_dir, mask_dir = [ - Path(COCO_ASSETS_DIR, dir_name) for dir_name in ['images', 'masks'] + Path(COCO_ASSETS_DIR, dir_name) for dir_name in ["images", "masks"] ] res = COCOConverter.deserialize_panoptic(panoptic_json, image_dir, mask_dir) - back = COCOConverter.serialize_panoptic(res, - Path(f'/{tmpdir}/images_panoptic'), - Path(f'/{tmpdir}/masks_panoptic')) + back = COCOConverter.serialize_panoptic( + res, + Path(f"/{tmpdir}/images_panoptic"), + Path(f"/{tmpdir}/masks_panoptic"), + ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py b/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py index c4b47427a..0bc3c8924 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py @@ 
-1,5 +1,9 @@ from labelbox.data.annotation_types.annotation import ClassificationAnnotation -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnswer, Radio +from labelbox.data.annotation_types.classification.classification import ( + Checklist, + ClassificationAnswer, + Radio, +) from labelbox.data.annotation_types.data.text import TextData from labelbox.data.annotation_types.label import Label @@ -17,18 +21,16 @@ def test_serialization_min(): ClassificationAnnotation( name="checkbox_question_geo", value=Checklist( - answer=[ClassificationAnswer(name="first_answer")]), + answer=[ClassificationAnswer(name="first_answer")] + ), ) - ]) + ], + ) expected = { - 'name': 'checkbox_question_geo', - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - }, - 'answer': [{ - 'name': 'first_answer' - }] + "name": "checkbox_question_geo", + "dataRow": {"id": "bkj7z2q0b0000jx6x0q2q7q0d"}, + "answer": [{"name": "first_answer"}], } serialized = NDJsonConverter.serialize([label]) res = next(serialized) @@ -54,61 +56,76 @@ def test_serialization_with_classification(): ClassificationAnnotation( name="checkbox_question_geo", confidence=0.5, - value=Checklist(answer=[ - ClassificationAnswer( - name="first_answer", - confidence=0.1, - classifications=[ - ClassificationAnnotation( - name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer", - confidence=0.31))), - ClassificationAnnotation( - name="sub_chck_question", - value=Checklist(answer=[ - ClassificationAnswer( - name="second_subchk_answer", - confidence=0.41), - ClassificationAnswer( - name="third_subchk_answer", - confidence=0.42), - ],)) - ]), - ])) - ]) + value=Checklist( + answer=[ + ClassificationAnswer( + name="first_answer", + confidence=0.1, + classifications=[ + ClassificationAnnotation( + name="sub_radio_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_radio_answer", + confidence=0.31, + ) + ), + ), + ClassificationAnnotation( + name="sub_chck_question", + value=Checklist( + answer=[ + ClassificationAnswer( + name="second_subchk_answer", + confidence=0.41, + ), + ClassificationAnswer( + name="third_subchk_answer", + confidence=0.42, + ), + ], + ), + ), + ], + ), + ] + ), + ) + ], + ) expected = { - 'confidence': - 0.5, - 'name': - 'checkbox_question_geo', - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - }, - 'answer': [{ - 'confidence': - 0.1, - 'name': - 'first_answer', - 'classifications': [{ - 'name': 'sub_radio_question', - 'answer': { - 'confidence': 0.31, - 'name': 'first_sub_radio_answer', - } - }, { - 'name': - 'sub_chck_question', - 'answer': [{ - 'confidence': 0.41, - 'name': 'second_subchk_answer', - }, { - 'confidence': 0.42, - 'name': 'third_subchk_answer', - }] - }] - }] + "confidence": 0.5, + "name": "checkbox_question_geo", + "dataRow": {"id": "bkj7z2q0b0000jx6x0q2q7q0d"}, + "answer": [ + { + "confidence": 0.1, + "name": "first_answer", + "classifications": [ + { + "name": "sub_radio_question", + "answer": { + "confidence": 0.31, + "name": "first_sub_radio_answer", + }, + }, + { + "name": "sub_chck_question", + "answer": [ + { + "confidence": 0.41, + "name": "second_subchk_answer", + }, + { + "confidence": 0.42, + "name": "third_subchk_answer", + }, + ], + }, + ], + } + ], } serialized = NDJsonConverter.serialize([label]) @@ -119,7 +136,9 @@ def test_serialization_with_classification(): deserialized = NDJsonConverter.deserialize([res]) res = next(deserialized) - assert label.model_dump(exclude_none=True) == 
label.model_dump(exclude_none=True) + assert label.model_dump(exclude_none=True) == label.model_dump( + exclude_none=True + ) def test_serialization_with_classification_double_nested(): @@ -133,66 +152,80 @@ def test_serialization_with_classification_double_nested(): ClassificationAnnotation( name="checkbox_question_geo", confidence=0.5, - value=Checklist(answer=[ - ClassificationAnswer( - name="first_answer", - confidence=0.1, - classifications=[ - ClassificationAnnotation( - name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer", - confidence=0.31, - classifications=[ - ClassificationAnnotation( - name="sub_chck_question", - value=Checklist(answer=[ - ClassificationAnswer( - name="second_subchk_answer", - confidence=0.41), - ClassificationAnswer( - name="third_subchk_answer", - confidence=0.42), - ],)) - ]))), - ]), - ])) - ]) + value=Checklist( + answer=[ + ClassificationAnswer( + name="first_answer", + confidence=0.1, + classifications=[ + ClassificationAnnotation( + name="sub_radio_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_radio_answer", + confidence=0.31, + classifications=[ + ClassificationAnnotation( + name="sub_chck_question", + value=Checklist( + answer=[ + ClassificationAnswer( + name="second_subchk_answer", + confidence=0.41, + ), + ClassificationAnswer( + name="third_subchk_answer", + confidence=0.42, + ), + ], + ), + ) + ], + ) + ), + ), + ], + ), + ] + ), + ) + ], + ) expected = { - 'confidence': - 0.5, - 'name': - 'checkbox_question_geo', - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - }, - 'answer': [{ - 'confidence': - 0.1, - 'name': - 'first_answer', - 'classifications': [{ - 'name': 'sub_radio_question', - 'answer': { - 'confidence': - 0.31, - 'name': - 'first_sub_radio_answer', - 'classifications': [{ - 'name': - 'sub_chck_question', - 'answer': [{ - 'confidence': 0.41, - 'name': 'second_subchk_answer', - }, { - 'confidence': 0.42, - 'name': 'third_subchk_answer', - }] - }] - } - }] - }] + "confidence": 0.5, + "name": "checkbox_question_geo", + "dataRow": {"id": "bkj7z2q0b0000jx6x0q2q7q0d"}, + "answer": [ + { + "confidence": 0.1, + "name": "first_answer", + "classifications": [ + { + "name": "sub_radio_question", + "answer": { + "confidence": 0.31, + "name": "first_sub_radio_answer", + "classifications": [ + { + "name": "sub_chck_question", + "answer": [ + { + "confidence": 0.41, + "name": "second_subchk_answer", + }, + { + "confidence": 0.42, + "name": "third_subchk_answer", + }, + ], + } + ], + }, + } + ], + } + ], } serialized = NDJsonConverter.serialize([label]) res = next(serialized) @@ -203,7 +236,9 @@ def test_serialization_with_classification_double_nested(): deserialized = NDJsonConverter.deserialize([res]) res = next(deserialized) res.annotations[0].extra.pop("uuid") - assert label.model_dump(exclude_none=True) == label.model_dump(exclude_none=True) + assert label.model_dump(exclude_none=True) == label.model_dump( + exclude_none=True + ) def test_serialization_with_classification_double_nested_2(): @@ -216,62 +251,79 @@ def test_serialization_with_classification_double_nested_2(): annotations=[ ClassificationAnnotation( name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer", - confidence=0.31, - classifications=[ - ClassificationAnnotation( - name="sub_chck_question", - value=Checklist(answer=[ - ClassificationAnswer( - name="second_subchk_answer", - confidence=0.41, - classifications=[ - ClassificationAnnotation( - 
name="checkbox_question_geo", - value=Checklist(answer=[ - ClassificationAnswer( - name="first_answer", - confidence=0.1), - ])) - ]), - ClassificationAnswer(name="third_subchk_answer", - confidence=0.42), - ])) - ]))), - ]) + value=Radio( + answer=ClassificationAnswer( + name="first_sub_radio_answer", + confidence=0.31, + classifications=[ + ClassificationAnnotation( + name="sub_chck_question", + value=Checklist( + answer=[ + ClassificationAnswer( + name="second_subchk_answer", + confidence=0.41, + classifications=[ + ClassificationAnnotation( + name="checkbox_question_geo", + value=Checklist( + answer=[ + ClassificationAnswer( + name="first_answer", + confidence=0.1, + ), + ] + ), + ) + ], + ), + ClassificationAnswer( + name="third_subchk_answer", + confidence=0.42, + ), + ] + ), + ) + ], + ) + ), + ), + ], + ) expected = { - 'name': 'sub_radio_question', - 'answer': { - 'confidence': - 0.31, - 'name': - 'first_sub_radio_answer', - 'classifications': [{ - 'name': - 'sub_chck_question', - 'answer': [{ - 'confidence': - 0.41, - 'name': - 'second_subchk_answer', - 'classifications': [{ - 'name': 'checkbox_question_geo', - 'answer': [{ - 'confidence': 0.1, - 'name': 'first_answer', - }] - }] - }, { - 'confidence': 0.42, - 'name': 'third_subchk_answer', - }] - }] + "name": "sub_radio_question", + "answer": { + "confidence": 0.31, + "name": "first_sub_radio_answer", + "classifications": [ + { + "name": "sub_chck_question", + "answer": [ + { + "confidence": 0.41, + "name": "second_subchk_answer", + "classifications": [ + { + "name": "checkbox_question_geo", + "answer": [ + { + "confidence": 0.1, + "name": "first_answer", + } + ], + } + ], + }, + { + "confidence": 0.42, + "name": "third_subchk_answer", + }, + ], + } + ], }, - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - } + "dataRow": {"id": "bkj7z2q0b0000jx6x0q2q7q0d"}, } serialized = NDJsonConverter.serialize([label]) @@ -281,4 +333,6 @@ def test_serialization_with_classification_double_nested_2(): deserialized = NDJsonConverter.deserialize([res]) res = next(deserialized) - assert label.model_dump(exclude_none=True) == label.model_dump(exclude_none=True) + assert label.model_dump(exclude_none=True) == label.model_dump( + exclude_none=True + ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_classification.py b/libs/labelbox/tests/data/serialization/ndjson/test_classification.py index 00a684b20..8dcb17f0b 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_classification.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_classification.py @@ -4,8 +4,9 @@ def test_classification(): - with open('tests/data/assets/ndjson/classification_import.json', - 'r') as file: + with open( + "tests/data/assets/ndjson/classification_import.json", "r" + ) as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) @@ -13,8 +14,9 @@ def test_classification(): def test_classification_with_name(): - with open('tests/data/assets/ndjson/classification_import_name_only.json', - 'r') as file: + with open( + "tests/data/assets/ndjson/classification_import_name_only.json", "r" + ) as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py b/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py index 4d2a0416c..f7da9181b 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py +++ 
b/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py @@ -4,99 +4,117 @@ import labelbox.types as lb_types from labelbox.data.serialization.ndjson.converter import NDJsonConverter -radio_ndjson = [{ - 'dataRow': { - 'globalKey': 'my_global_key' - }, - 'name': 'radio', - 'answer': { - 'name': 'first_radio_answer' - }, - 'messageId': '0' -}] +radio_ndjson = [ + { + "dataRow": {"globalKey": "my_global_key"}, + "name": "radio", + "answer": {"name": "first_radio_answer"}, + "messageId": "0", + } +] radio_label = [ lb_types.Label( - data=lb_types.ConversationData(global_key='my_global_key'), + data=lb_types.ConversationData(global_key="my_global_key"), annotations=[ lb_types.ClassificationAnnotation( - name='radio', - value=lb_types.Radio(answer=lb_types.ClassificationAnswer( - name="first_radio_answer")), - message_id="0") - ]) + name="radio", + value=lb_types.Radio( + answer=lb_types.ClassificationAnswer( + name="first_radio_answer" + ) + ), + message_id="0", + ) + ], + ) ] -checklist_ndjson = [{ - 'dataRow': { - 'globalKey': 'my_global_key' - }, - 'name': 'checklist', - 'answer': [ - { - 'name': 'first_checklist_answer' - }, - { - 'name': 'second_checklist_answer' - }, - ], - 'messageId': '2' -}] +checklist_ndjson = [ + { + "dataRow": {"globalKey": "my_global_key"}, + "name": "checklist", + "answer": [ + {"name": "first_checklist_answer"}, + {"name": "second_checklist_answer"}, + ], + "messageId": "2", + } +] checklist_label = [ - lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'), - annotations=[ - lb_types.ClassificationAnnotation( - name='checklist', - message_id="2", - value=lb_types.Checklist(answer=[ - lb_types.ClassificationAnswer( - name="first_checklist_answer"), - lb_types.ClassificationAnswer( - name="second_checklist_answer") - ])) - ]) + lb_types.Label( + data=lb_types.ConversationData(global_key="my_global_key"), + annotations=[ + lb_types.ClassificationAnnotation( + name="checklist", + message_id="2", + value=lb_types.Checklist( + answer=[ + lb_types.ClassificationAnswer( + name="first_checklist_answer" + ), + lb_types.ClassificationAnswer( + name="second_checklist_answer" + ), + ] + ), + ) + ], + ) ] -free_text_ndjson = [{ - 'dataRow': { - 'globalKey': 'my_global_key' - }, - 'name': 'free_text', - 'answer': 'sample text', - 'messageId': '0' -}] +free_text_ndjson = [ + { + "dataRow": {"globalKey": "my_global_key"}, + "name": "free_text", + "answer": "sample text", + "messageId": "0", + } +] free_text_label = [ - lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'), - annotations=[ - lb_types.ClassificationAnnotation( - name='free_text', - message_id="0", - value=lb_types.Text(answer="sample text")) - ]) + lb_types.Label( + data=lb_types.ConversationData(global_key="my_global_key"), + annotations=[ + lb_types.ClassificationAnnotation( + name="free_text", + message_id="0", + value=lb_types.Text(answer="sample text"), + ) + ], + ) ] @pytest.mark.parametrize( "label, ndjson", - [[radio_label, radio_ndjson], [checklist_label, checklist_ndjson], - [free_text_label, free_text_ndjson]]) + [ + [radio_label, radio_ndjson], + [checklist_label, checklist_ndjson], + [free_text_label, free_text_ndjson], + ], +) def test_message_based_radio_classification(label, ndjson): serialized_label = list(NDJsonConverter().serialize(label)) - serialized_label[0].pop('uuid') + serialized_label[0].pop("uuid") assert serialized_label == ndjson deserialized_label = list(NDJsonConverter().deserialize(ndjson)) - 
deserialized_label[0].annotations[0].extra.pop('uuid') - assert deserialized_label[0].model_dump(exclude_none=True) == label[0].model_dump(exclude_none=True) + deserialized_label[0].annotations[0].extra.pop("uuid") + assert deserialized_label[0].model_dump(exclude_none=True) == label[ + 0 + ].model_dump(exclude_none=True) -@pytest.mark.parametrize("filename", [ - "tests/data/assets/ndjson/conversation_entity_import.json", - "tests/data/assets/ndjson/conversation_entity_without_confidence_import.json" -]) +@pytest.mark.parametrize( + "filename", + [ + "tests/data/assets/ndjson/conversation_entity_import.json", + "tests/data/assets/ndjson/conversation_entity_without_confidence_import.json", + ], +) def test_conversation_entity_import(filename: str): - with open(filename, 'r') as file: + with open(filename, "r") as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) @@ -104,30 +122,34 @@ def test_conversation_entity_import(filename: str): def test_benchmark_reference_label_flag_enabled(): - label = lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'), - annotations=[ - lb_types.ClassificationAnnotation( - name='free_text', - message_id="0", - value=lb_types.Text(answer="sample text")) - ], - is_benchmark_reference=True - ) + label = lb_types.Label( + data=lb_types.ConversationData(global_key="my_global_key"), + annotations=[ + lb_types.ClassificationAnnotation( + name="free_text", + message_id="0", + value=lb_types.Text(answer="sample text"), + ) + ], + is_benchmark_reference=True, + ) res = list(NDJsonConverter.serialize([label])) assert res[0]["isBenchmarkReferenceLabel"] def test_benchmark_reference_label_flag_disabled(): - label = lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'), - annotations=[ - lb_types.ClassificationAnnotation( - name='free_text', - message_id="0", - value=lb_types.Text(answer="sample text")) - ], - is_benchmark_reference=False - ) + label = lb_types.Label( + data=lb_types.ConversationData(global_key="my_global_key"), + annotations=[ + lb_types.ClassificationAnnotation( + name="free_text", + message_id="0", + value=lb_types.Text(answer="sample text"), + ) + ], + is_benchmark_reference=False, + ) res = list(NDJsonConverter.serialize([label])) assert not res[0].get("isBenchmarkReferenceLabel") diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py b/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py index 186c75223..333c00250 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py @@ -2,35 +2,41 @@ import pytest import labelbox.types as lb_types from labelbox.data.serialization import NDJsonConverter -from labelbox.data.serialization.ndjson.objects import NDDicomSegments, NDDicomSegment, NDDicomLine +from labelbox.data.serialization.ndjson.objects import ( + NDDicomSegments, + NDDicomSegment, + NDDicomLine, +) + """ Data gen prompt test data """ prompt_text_annotation = lb_types.PromptClassificationAnnotation( - feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", - name="test", - value=lb_types.PromptText(answer="the answer to the text questions right here"), - ) - -prompt_text_ndjson = { - "answer": "the answer to the text questions right here", - "name": "test", - "schemaId": "ckrb1sfkn099c0y910wbo0p1a", - "dataRow": { - "id": "ckrb1sf1i1g7i0ybcdc6oc8ct" - }, - } + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + name="test", + 
value=lb_types.PromptText( + answer="the answer to the text questions right here" + ), +) + +prompt_text_ndjson = { + "answer": "the answer to the text questions right here", + "name": "test", + "schemaId": "ckrb1sfkn099c0y910wbo0p1a", + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, +} data_gen_label = lb_types.Label( data={"uid": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, - annotations=[prompt_text_annotation] + annotations=[prompt_text_annotation], ) - + """ Prompt annotation test """ + def test_serialize_label(): serialized_label = next(NDJsonConverter().serialize([data_gen_label])) # Remove uuid field since this is a random value that can not be specified also meant for relationships @@ -39,17 +45,23 @@ def test_serialize_label(): def test_deserialize_label(): - deserialized_label = next(NDJsonConverter().deserialize([prompt_text_ndjson])) - if hasattr(deserialized_label.annotations[0], 'extra'): + deserialized_label = next( + NDJsonConverter().deserialize([prompt_text_ndjson]) + ) + if hasattr(deserialized_label.annotations[0], "extra"): # Extra fields are added to deserialized label by default need removed to match deserialized_label.annotations[0].extra = {} - assert deserialized_label.model_dump(exclude_none=True) == data_gen_label.model_dump(exclude_none=True) + assert deserialized_label.model_dump( + exclude_none=True + ) == data_gen_label.model_dump(exclude_none=True) def test_serialize_deserialize_label(): serialized = list(NDJsonConverter.serialize([data_gen_label])) deserialized = next(NDJsonConverter.deserialize(serialized)) - if hasattr(deserialized.annotations[0], 'extra'): + if hasattr(deserialized.annotations[0], "extra"): # Extra fields are added to deserialized label by default need removed to match deserialized.annotations[0].extra = {} - assert deserialized.model_dump(exclude_none=True) == data_gen_label.model_dump(exclude_none=True) + assert deserialized.model_dump( + exclude_none=True + ) == data_gen_label.model_dump(exclude_none=True) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py b/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py index e69c21bae..633214367 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py @@ -3,127 +3,120 @@ import base64 import labelbox.types as lb_types from labelbox.data.serialization import NDJsonConverter -from labelbox.data.serialization.ndjson.objects import NDDicomSegments, NDDicomSegment, NDDicomLine +from labelbox.data.serialization.ndjson.objects import ( + NDDicomSegments, + NDDicomSegment, + NDDicomLine, +) + """ Polyline test data """ dicom_polyline_annotations = [ - lb_types.DICOMObjectAnnotation(uuid="78a8a027-9089-420c-8348-6099eb77e4aa", - name="dicom_polyline", - frame=2, - value=lb_types.Line(points=[ - lb_types.Point(x=680, y=100), - lb_types.Point(x=100, y=190), - lb_types.Point(x=190, y=220) - ]), - segment_index=0, - keyframe=True, - group_key=lb_types.GroupKey.AXIAL) + lb_types.DICOMObjectAnnotation( + uuid="78a8a027-9089-420c-8348-6099eb77e4aa", + name="dicom_polyline", + frame=2, + value=lb_types.Line( + points=[ + lb_types.Point(x=680, y=100), + lb_types.Point(x=100, y=190), + lb_types.Point(x=190, y=220), + ] + ), + segment_index=0, + keyframe=True, + group_key=lb_types.GroupKey.AXIAL, + ) ] -polyline_label = lb_types.Label(data=lb_types.DicomData(uid="test-uid"), - annotations=dicom_polyline_annotations) +polyline_label = lb_types.Label( + data=lb_types.DicomData(uid="test-uid"), + 
annotations=dicom_polyline_annotations, +) polyline_annotation_ndjson = { - 'classifications': [], - 'dataRow': { - 'id': 'test-uid' - }, - 'name': - 'dicom_polyline', - 'groupKey': - 'axial', - 'segments': [{ - 'keyframes': [{ - 'frame': 2, - 'line': [ - { - 'x': 680.0, - 'y': 100.0 - }, + "classifications": [], + "dataRow": {"id": "test-uid"}, + "name": "dicom_polyline", + "groupKey": "axial", + "segments": [ + { + "keyframes": [ { - 'x': 100.0, - 'y': 190.0 - }, - { - 'x': 190.0, - 'y': 220.0 - }, - ], - 'classifications': [], - }] - }], + "frame": 2, + "line": [ + {"x": 680.0, "y": 100.0}, + {"x": 100.0, "y": 190.0}, + {"x": 190.0, "y": 220.0}, + ], + "classifications": [], + } + ] + } + ], } polyline_with_global_key = lb_types.Label( data=lb_types.DicomData(global_key="test-global-key"), - annotations=dicom_polyline_annotations) + annotations=dicom_polyline_annotations, +) polyline_annotation_ndjson_with_global_key = copy(polyline_annotation_ndjson) -polyline_annotation_ndjson_with_global_key['dataRow'] = { - 'globalKey': 'test-global-key' +polyline_annotation_ndjson_with_global_key["dataRow"] = { + "globalKey": "test-global-key" } """ Video test data """ -instance_uri_1 = 'https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA' -instance_uri_5 = 'https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA' +instance_uri_1 = "https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA" +instance_uri_5 = "https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA" frames = [ lb_types.MaskFrame(index=1, instance_uri=instance_uri_1), - lb_types.MaskFrame(index=5, instance_uri=instance_uri_5) + lb_types.MaskFrame(index=5, instance_uri=instance_uri_5), ] instances = [ lb_types.MaskInstance(color_rgb=(0, 0, 255), name="mask1"), lb_types.MaskInstance(color_rgb=(0, 255, 0), name="mask2"), - lb_types.MaskInstance(color_rgb=(255, 0, 0), name="mask3") + lb_types.MaskInstance(color_rgb=(255, 0, 0), name="mask3"), ] -video_mask_annotation = lb_types.VideoMaskAnnotation(frames=frames, - instances=instances) +video_mask_annotation = lb_types.VideoMaskAnnotation( + frames=frames, instances=instances +) video_mask_annotation_ndjson = { - 'dataRow': { - 'id': 'test-uid' - }, - 'masks': { - 'frames': [{ - 'index': 1, - 'instanceURI': instance_uri_1 - }, { - 'index': 5, - 'instanceURI': instance_uri_5 - }], - 'instances': [ - { - 'colorRGB': (0, 0, 255), - 'name': 'mask1' - }, - { - 'colorRGB': (0, 255, 0), - 'name': 'mask2' - }, - { - 'colorRGB': (255, 0, 0), - 'name': 'mask3' - }, - ] + "dataRow": {"id": "test-uid"}, + "masks": { + "frames": [ + {"index": 1, "instanceURI": instance_uri_1}, + {"index": 5, "instanceURI": instance_uri_5}, + ], + "instances": [ + {"colorRGB": (0, 0, 255), "name": "mask1"}, + {"colorRGB": (0, 255, 0), "name": "mask2"}, + {"colorRGB": (255, 0, 0), "name": "mask3"}, + ], }, } video_mask_annotation_ndjson_with_global_key = copy( - video_mask_annotation_ndjson) -video_mask_annotation_ndjson_with_global_key['dataRow'] = { - 'globalKey': 'test-global-key' + 
video_mask_annotation_ndjson +) +video_mask_annotation_ndjson_with_global_key["dataRow"] = { + "globalKey": "test-global-key" } -video_mask_label = lb_types.Label(data=lb_types.VideoData(uid="test-uid"), - annotations=[video_mask_annotation]) +video_mask_label = lb_types.Label( + data=lb_types.VideoData(uid="test-uid"), annotations=[video_mask_annotation] +) video_mask_label_with_global_key = lb_types.Label( data=lb_types.VideoData(global_key="test-global-key"), - annotations=[video_mask_annotation]) + annotations=[video_mask_annotation], +) """ DICOM Mask test data """ @@ -132,30 +125,37 @@ name="dicom_mask", group_key=lb_types.GroupKey.AXIAL, frames=frames, - instances=instances) + instances=instances, +) -dicom_mask_label = lb_types.Label(data=lb_types.DicomData(uid="test-uid"), - annotations=[dicom_mask_annotation]) +dicom_mask_label = lb_types.Label( + data=lb_types.DicomData(uid="test-uid"), annotations=[dicom_mask_annotation] +) dicom_mask_label_with_global_key = lb_types.Label( data=lb_types.DicomData(global_key="test-global-key"), - annotations=[dicom_mask_annotation]) + annotations=[dicom_mask_annotation], +) dicom_mask_annotation_ndjson = copy(video_mask_annotation_ndjson) -dicom_mask_annotation_ndjson['groupKey'] = 'axial' +dicom_mask_annotation_ndjson["groupKey"] = "axial" dicom_mask_annotation_ndjson_with_global_key = copy( - dicom_mask_annotation_ndjson) -dicom_mask_annotation_ndjson_with_global_key['dataRow'] = { - 'globalKey': 'test-global-key' + dicom_mask_annotation_ndjson +) +dicom_mask_annotation_ndjson_with_global_key["dataRow"] = { + "globalKey": "test-global-key" } """ Tests """ labels = [ - polyline_label, polyline_with_global_key, dicom_mask_label, - dicom_mask_label_with_global_key, video_mask_label, - video_mask_label_with_global_key + polyline_label, + polyline_with_global_key, + dicom_mask_label, + dicom_mask_label_with_global_key, + video_mask_label, + video_mask_label_with_global_key, ] ndjsons = [ polyline_annotation_ndjson, @@ -175,32 +175,31 @@ def test_deserialize_nd_dicom_segments(): assert isinstance(nd_dicom_segments.segments[0].keyframes[0], NDDicomLine) -@pytest.mark.parametrize('label, ndjson', labels_ndjsons) +@pytest.mark.parametrize("label, ndjson", labels_ndjsons) def test_serialize_label(label, ndjson): serialized_label = next(NDJsonConverter().serialize([label])) if "uuid" in serialized_label: - serialized_label.pop('uuid') + serialized_label.pop("uuid") assert serialized_label == ndjson -@pytest.mark.parametrize('label, ndjson', labels_ndjsons) +@pytest.mark.parametrize("label, ndjson", labels_ndjsons) def test_deserialize_label(label, ndjson): deserialized_label = next(NDJsonConverter().deserialize([ndjson])) - if hasattr(deserialized_label.annotations[0], 'extra'): + if hasattr(deserialized_label.annotations[0], "extra"): deserialized_label.annotations[0].extra = {} for i, annotation in enumerate(deserialized_label.annotations): if hasattr(annotation, "frames"): assert annotation.frames == label.annotations[i].frames if hasattr(annotation, "value"): assert annotation.value == label.annotations[i].value - -@pytest.mark.parametrize('label', labels) +@pytest.mark.parametrize("label", labels) def test_serialize_deserialize_label(label): serialized = list(NDJsonConverter.serialize([label])) deserialized = list(NDJsonConverter.deserialize(serialized)) - if hasattr(deserialized[0].annotations[0], 'extra'): + if hasattr(deserialized[0].annotations[0], "extra"): deserialized[0].annotations[0].extra = {} for i, annotation in 
enumerate(deserialized[0].annotations): if hasattr(annotation, "frames"): diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_document.py b/libs/labelbox/tests/data/serialization/ndjson/test_document.py index cdfbbbb88..5fe6a9789 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_document.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_document.py @@ -8,26 +8,30 @@ start=lb_types.Point(x=42.799, y=86.498), # Top left end=lb_types.Point(x=141.911, y=303.195), # Bottom right page=1, - unit=lb_types.RectangleUnit.POINTS)) + unit=lb_types.RectangleUnit.POINTS, + ), +) bbox_labels = [ - lb_types.Label(data=lb_types.DocumentData(global_key='test-global-key'), - annotations=[bbox_annotation]) + lb_types.Label( + data=lb_types.DocumentData(global_key="test-global-key"), + annotations=[bbox_annotation], + ) +] +bbox_ndjson = [ + { + "bbox": { + "height": 216.697, + "left": 42.799, + "top": 86.498, + "width": 99.112, + }, + "classifications": [], + "dataRow": {"globalKey": "test-global-key"}, + "name": "bounding_box", + "page": 1, + "unit": "POINTS", + } ] -bbox_ndjson = [{ - 'bbox': { - 'height': 216.697, - 'left': 42.799, - 'top': 86.498, - 'width': 99.112, - }, - 'classifications': [], - 'dataRow': { - 'globalKey': 'test-global-key' - }, - 'name': 'bounding_box', - 'page': 1, - 'unit': 'POINTS' -}] def round_dict(data): @@ -47,7 +51,7 @@ def test_pdf(): """ Tests a pdf file with bbox annotations only """ - with open('tests/data/assets/ndjson/pdf_import.json', 'r') as f: + with open("tests/data/assets/ndjson/pdf_import.json", "r") as f: data = json.load(f) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) @@ -59,7 +63,7 @@ def test_pdf_with_name_only(): """ Tests a pdf file with bbox annotations only """ - with open('tests/data/assets/ndjson/pdf_import_name_only.json', 'r') as f: + with open("tests/data/assets/ndjson/pdf_import_name_only.json", "r") as f: data = json.load(f) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) @@ -69,12 +73,18 @@ def test_pdf_with_name_only(): def test_pdf_bbox_serialize(): serialized = list(NDJsonConverter.serialize(bbox_labels)) - serialized[0].pop('uuid') + serialized[0].pop("uuid") assert serialized == bbox_ndjson def test_pdf_bbox_deserialize(): deserialized = list(NDJsonConverter.deserialize(bbox_ndjson)) deserialized[0].annotations[0].extra = {} - assert deserialized[0].annotations[0].value == bbox_labels[0].annotations[0].value - assert deserialized[0].annotations[0].name == bbox_labels[0].annotations[0].name \ No newline at end of file + assert ( + deserialized[0].annotations[0].value + == bbox_labels[0].annotations[0].value + ) + assert ( + deserialized[0].annotations[0].name + == bbox_labels[0].annotations[0].name + ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py b/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py index c85b48234..4adcd9935 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py @@ -6,588 +6,580 @@ def video_bbox_label(): return Label( - uid='cl1z52xwh00050fhcmfgczqvn', + uid="cl1z52xwh00050fhcmfgczqvn", data=VideoData( uid="cklr9mr4m5iao0rb6cvxu4qbn", file_path=None, frames=None, - url= - 
"https://storage.labelbox.com/ckcz6bubudyfi0855o1dt1g9s%2F26403a22-604a-a38c-eeff-c2ed481fb40a-cat.mp4?Expires=1651677421050&KeyName=labelbox-assets-key-3&Signature=vF7gMyfHzgZdfbB8BHgd88Ws-Ms" + url="https://storage.labelbox.com/ckcz6bubudyfi0855o1dt1g9s%2F26403a22-604a-a38c-eeff-c2ed481fb40a-cat.mp4?Expires=1651677421050&KeyName=labelbox-assets-key-3&Signature=vF7gMyfHzgZdfbB8BHgd88Ws-Ms", ), annotations=[ - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=46.0), - end=Point(extra={}, - x=454.0, - y=295.0)), - classifications=[], - frame=1, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=42.5), - end=Point(extra={}, - x=427.25, - y=308.25)), - classifications=[], - frame=2, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=39.0), - end=Point(extra={}, - x=400.5, - y=321.5)), - classifications=[], - frame=3, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=35.5), - end=Point(extra={}, - x=373.75, - y=334.75)), - classifications=[], - frame=4, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=32.0), - end=Point(extra={}, - x=347.0, - y=348.0)), - classifications=[], - frame=5, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=132.0), - end=Point(extra={}, - x=283.0, - y=348.0)), - classifications=[], - frame=9, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=122.333), - end=Point(extra={}, - x=295.5, - y=348.0)), - classifications=[], - frame=10, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=112.667), - end=Point(extra={}, - x=308.0, - y=348.0)), - classifications=[], - frame=11, - keyframe=False), - 
VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=103.0), - end=Point(extra={}, - x=320.5, - y=348.0)), - classifications=[], - frame=12, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=93.333), - end=Point(extra={}, - x=333.0, - y=348.0)), - classifications=[], - frame=13, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=83.667), - end=Point(extra={}, - x=345.5, - y=348.0)), - classifications=[], - frame=14, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=74.0), - end=Point(extra={}, - x=358.0, - y=348.0)), - classifications=[], - frame=15, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=66.833), - end=Point(extra={}, - x=387.333, - y=348.0)), - classifications=[], - frame=16, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=59.667), - end=Point(extra={}, - x=416.667, - y=348.0)), - classifications=[], - frame=17, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=52.5), - end=Point(extra={}, - x=446.0, - y=348.0)), - classifications=[], - frame=18, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=45.333), - end=Point(extra={}, - x=475.333, - y=348.0)), - classifications=[], - frame=19, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=38.167), - end=Point(extra={}, - x=504.667, - y=348.0)), - classifications=[], - frame=20, - keyframe=False), - VideoObjectAnnotation(name='bbox 
toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=31.0), - end=Point(extra={}, - x=534.0, - y=348.0)), - classifications=[], - frame=21, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=29.5), - end=Point(extra={}, - x=543.0, - y=348.0)), - classifications=[], - frame=22, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=28.0), - end=Point(extra={}, - x=552.0, - y=348.0)), - classifications=[], - frame=23, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=26.5), - end=Point(extra={}, - x=561.0, - y=348.0)), - classifications=[], - frame=24, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=25.0), - end=Point(extra={}, - x=570.0, - y=348.0)), - classifications=[], - frame=25, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=23.5), - end=Point(extra={}, - x=579.0, - y=348.0)), - classifications=[], - frame=26, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=22.0), - end=Point(extra={}, - x=588.0, - y=348.0)), - classifications=[], - frame=27, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=20.5), - end=Point(extra={}, - x=597.0, - y=348.0)), - classifications=[], - frame=28, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=19.0), - end=Point(extra={}, - x=606.0, - y=348.0)), - classifications=[], - frame=29, - keyframe=True) - ], - extra={ - 'Created By': - 'jtso@labelbox.com', - 'Project Name': - 'Pictor Video', - 
'Created At': - '2022-04-14T15:11:19.000Z', - 'Updated At': - '2022-04-14T15:11:21.064Z', - 'Seconds to Label': - 0.0, - 'Agreement': - -1.0, - 'Benchmark Agreement': - -1.0, - 'Benchmark ID': - None, - 'Dataset Name': - 'cat', - 'Reviews': [], - 'View Label': - 'https://editor.labelbox.com?project=ckz38nsfd0lzq109bhq73est1&label=cl1z52xwh00050fhcmfgczqvn', - 'Has Open Issues': - 0.0, - 'Skipped': - False, - 'media_type': - 'video', - 'Data Split': - None - }) - - -def video_serialized_bbox_label(): - return { - 'uuid': - 'b24e672b-8f79-4d96-bf5e-b552ca0820d5', - 'dataRow': { - 'id': 'cklr9mr4m5iao0rb6cvxu4qbn' - }, - 'schemaId': - 'ckz38ofop0mci0z9i9w3aa9o4', - 'name': - 'bbox toy', - 'classifications': [], - 'segments': [{ - 'keyframes': [{ - 'frame': 1, - 'bbox': { - 'top': 46.0, - 'left': 70.0, - 'height': 249.0, - 'width': 384.0 + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=46.0), + end=Point(extra={}, x=454.0, y=295.0), + ), + classifications=[], + frame=1, + keyframe=True, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=42.5), + end=Point(extra={}, x=427.25, y=308.25), + ), + classifications=[], + frame=2, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=39.0), + end=Point(extra={}, x=400.5, y=321.5), + ), + classifications=[], + frame=3, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=35.5), + end=Point(extra={}, x=373.75, y=334.75), + ), + classifications=[], + frame=4, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=32.0), + end=Point(extra={}, x=347.0, y=348.0), + ), + classifications=[], + frame=5, + keyframe=True, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=132.0), + end=Point(extra={}, x=283.0, y=348.0), + ), + classifications=[], + frame=9, + keyframe=True, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=122.333), + end=Point(extra={}, x=295.5, y=348.0), 
+ ), + classifications=[], + frame=10, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=112.667), + end=Point(extra={}, x=308.0, y=348.0), + ), + classifications=[], + frame=11, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=103.0), + end=Point(extra={}, x=320.5, y=348.0), + ), + classifications=[], + frame=12, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=93.333), + end=Point(extra={}, x=333.0, y=348.0), + ), + classifications=[], + frame=13, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=83.667), + end=Point(extra={}, x=345.5, y=348.0), + ), + classifications=[], + frame=14, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", }, - 'classifications': [] - }, { - 'frame': 5, - 'bbox': { - 'top': 32.0, - 'left': 70.0, - 'height': 316.0, - 'width': 277.0 + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=74.0), + end=Point(extra={}, x=358.0, y=348.0), + ), + classifications=[], + frame=15, + keyframe=True, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", }, - 'classifications': [] - }] - }, { - 'keyframes': [{ - 'frame': 9, - 'bbox': { - 'top': 132.0, - 'left': 70.0, - 'height': 216.0, - 'width': 213.0 + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=66.833), + end=Point(extra={}, x=387.333, y=348.0), + ), + classifications=[], + frame=16, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", }, - 'classifications': [] - }, { - 'frame': 15, - 'bbox': { - 'top': 74.0, - 'left': 70.0, - 'height': 274.0, - 'width': 288.0 + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=59.667), + end=Point(extra={}, x=416.667, y=348.0), + ), + classifications=[], + frame=17, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", }, - 'classifications': [] - }, { - 'frame': 21, - 'bbox': { - 'top': 31.0, - 'left': 70.0, - 
'height': 317.0, - 'width': 464.0 + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=52.5), + end=Point(extra={}, x=446.0, y=348.0), + ), + classifications=[], + frame=18, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", }, - 'classifications': [] - }, { - 'frame': 29, - 'bbox': { - 'top': 19.0, - 'left': 70.0, - 'height': 329.0, - 'width': 536.0 + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=45.333), + end=Point(extra={}, x=475.333, y=348.0), + ), + classifications=[], + frame=19, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", }, - 'classifications': [] - }] - }] + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=38.167), + end=Point(extra={}, x=504.667, y=348.0), + ), + classifications=[], + frame=20, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=31.0), + end=Point(extra={}, x=534.0, y=348.0), + ), + classifications=[], + frame=21, + keyframe=True, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=29.5), + end=Point(extra={}, x=543.0, y=348.0), + ), + classifications=[], + frame=22, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=28.0), + end=Point(extra={}, x=552.0, y=348.0), + ), + classifications=[], + frame=23, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=26.5), + end=Point(extra={}, x=561.0, y=348.0), + ), + classifications=[], + frame=24, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=25.0), + end=Point(extra={}, x=570.0, y=348.0), + ), + classifications=[], + frame=25, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=23.5), + end=Point(extra={}, x=579.0, y=348.0), + ), + classifications=[], + frame=26, + keyframe=False, + ), + 
VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=22.0), + end=Point(extra={}, x=588.0, y=348.0), + ), + classifications=[], + frame=27, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=20.5), + end=Point(extra={}, x=597.0, y=348.0), + ), + classifications=[], + frame=28, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=19.0), + end=Point(extra={}, x=606.0, y=348.0), + ), + classifications=[], + frame=29, + keyframe=True, + ), + ], + extra={ + "Created By": "jtso@labelbox.com", + "Project Name": "Pictor Video", + "Created At": "2022-04-14T15:11:19.000Z", + "Updated At": "2022-04-14T15:11:21.064Z", + "Seconds to Label": 0.0, + "Agreement": -1.0, + "Benchmark Agreement": -1.0, + "Benchmark ID": None, + "Dataset Name": "cat", + "Reviews": [], + "View Label": "https://editor.labelbox.com?project=ckz38nsfd0lzq109bhq73est1&label=cl1z52xwh00050fhcmfgczqvn", + "Has Open Issues": 0.0, + "Skipped": False, + "media_type": "video", + "Data Split": None, + }, + ) + + +def video_serialized_bbox_label(): + return { + "uuid": "b24e672b-8f79-4d96-bf5e-b552ca0820d5", + "dataRow": {"id": "cklr9mr4m5iao0rb6cvxu4qbn"}, + "schemaId": "ckz38ofop0mci0z9i9w3aa9o4", + "name": "bbox toy", + "classifications": [], + "segments": [ + { + "keyframes": [ + { + "frame": 1, + "bbox": { + "top": 46.0, + "left": 70.0, + "height": 249.0, + "width": 384.0, + }, + "classifications": [], + }, + { + "frame": 5, + "bbox": { + "top": 32.0, + "left": 70.0, + "height": 316.0, + "width": 277.0, + }, + "classifications": [], + }, + ] + }, + { + "keyframes": [ + { + "frame": 9, + "bbox": { + "top": 132.0, + "left": 70.0, + "height": 216.0, + "width": 213.0, + }, + "classifications": [], + }, + { + "frame": 15, + "bbox": { + "top": 74.0, + "left": 70.0, + "height": 274.0, + "width": 288.0, + }, + "classifications": [], + }, + { + "frame": 21, + "bbox": { + "top": 31.0, + "left": 70.0, + "height": 317.0, + "width": 464.0, + }, + "classifications": [], + }, + { + "frame": 29, + "bbox": { + "top": 19.0, + "left": 70.0, + "height": 329.0, + "width": 536.0, + }, + "classifications": [], + }, + ] + }, + ], } @@ -603,9 +595,9 @@ def test_serialize_video_objects(): if key != "uuid": assert label[key] == manual_label[key] - assert len(label['segments']) == 2 - assert len(label['segments'][0]['keyframes']) == 2 - assert len(label['segments'][1]['keyframes']) == 4 + assert len(label["segments"]) == 2 + assert len(label["segments"][0]["keyframes"]) == 2 + assert len(label["segments"][1]["keyframes"]) == 4 # #converts back only the keyframes. 
should be the sum of all prev segments deserialized_labels = NDJsonConverter.deserialize([label]) @@ -618,7 +610,7 @@ def test_confidence_is_ignored(): serialized_labels = NDJsonConverter.serialize([label]) label = next(serialized_labels) label["confidence"] = 0.453 - label['segments'][0]["confidence"] = 0.453 + label["segments"][0]["confidence"] = 0.453 deserialized_labels = NDJsonConverter.deserialize([label]) label = next(deserialized_labels) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py b/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py index aaa84953a..84c017497 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py @@ -1,5 +1,10 @@ from labelbox.data.annotation_types.annotation import ClassificationAnnotation -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnswer, Radio, Text +from labelbox.data.annotation_types.classification.classification import ( + Checklist, + ClassificationAnswer, + Radio, + Text, +) from labelbox.data.annotation_types.data.text import TextData from labelbox.data.annotation_types.label import Label @@ -7,24 +12,27 @@ def test_serialization(): - label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation(name="free_text_annotation", - value=Text(confidence=0.5, - answer="text_answer")) - ]) + label = Label( + uid="ckj7z2q0b0000jx6x0q2q7q0d", + data=TextData( + uid="bkj7z2q0b0000jx6x0q2q7q0d", + text="This is a test", + ), + annotations=[ + ClassificationAnnotation( + name="free_text_annotation", + value=Text(confidence=0.5, answer="text_answer"), + ) + ], + ) serialized = NDJsonConverter.serialize([label]) res = next(serialized) - assert res['confidence'] == 0.5 - assert res['name'] == "free_text_annotation" - assert res['answer'] == "text_answer" - assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d" + assert res["confidence"] == 0.5 + assert res["name"] == "free_text_annotation" + assert res["answer"] == "text_answer" + assert res["dataRow"]["id"] == "bkj7z2q0b0000jx6x0q2q7q0d" deserialized = NDJsonConverter.deserialize([res]) res = next(deserialized) @@ -47,44 +55,53 @@ def test_nested_serialization(): annotations=[ ClassificationAnnotation( name="nested test", - value=Checklist(answer=[ - ClassificationAnswer( - name="first_answer", - confidence=0.9, - classifications=[ - ClassificationAnnotation( - name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer", - confidence=0.8, - classifications=[ - ClassificationAnnotation( - name="nested answer", - value=Text( - answer="nested answer", - confidence=0.7, - )) - ]))) - ]) - ]), + value=Checklist( + answer=[ + ClassificationAnswer( + name="first_answer", + confidence=0.9, + classifications=[ + ClassificationAnnotation( + name="sub_radio_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_radio_answer", + confidence=0.8, + classifications=[ + ClassificationAnnotation( + name="nested answer", + value=Text( + answer="nested answer", + confidence=0.7, + ), + ) + ], + ) + ), + ) + ], + ) + ] + ), ) - ]) + ], + ) serialized = NDJsonConverter.serialize([label]) res = next(serialized) - assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d" - answer = res['answer'][0] - assert answer['confidence'] == 0.9 - assert answer['name'] == "first_answer" - 
classification = answer['classifications'][0]
-    nested_classification_answer = classification['answer']
-    assert nested_classification_answer['confidence'] == 0.8
-    assert nested_classification_answer['name'] == "first_sub_radio_answer"
-    sub_classification = nested_classification_answer['classifications'][0]
-    assert sub_classification['name'] == "nested answer"
-    assert sub_classification['answer'] == "nested answer"
-    assert sub_classification['confidence'] == 0.7
+    assert res["dataRow"]["id"] == "bkj7z2q0b0000jx6x0q2q7q0d"
+    answer = res["answer"][0]
+    assert answer["confidence"] == 0.9
+    assert answer["name"] == "first_answer"
+    classification = answer["classifications"][0]
+    nested_classification_answer = classification["answer"]
+    assert nested_classification_answer["confidence"] == 0.8
+    assert nested_classification_answer["name"] == "first_sub_radio_answer"
+    sub_classification = nested_classification_answer["classifications"][0]
+    assert sub_classification["name"] == "nested answer"
+    assert sub_classification["answer"] == "nested answer"
+    assert sub_classification["confidence"] == 0.7
 
     deserialized = NDJsonConverter.deserialize([res])
     res = next(deserialized)
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py b/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py
index 6de2dcc51..2b3fa7f8c 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py
@@ -20,15 +20,18 @@ def round_dict(data):
     return data
 
 
-@pytest.mark.parametrize('filename', [
-    'tests/data/assets/ndjson/classification_import_global_key.json',
-    'tests/data/assets/ndjson/metric_import_global_key.json',
-    'tests/data/assets/ndjson/polyline_import_global_key.json',
-    'tests/data/assets/ndjson/text_entity_import_global_key.json',
-    'tests/data/assets/ndjson/conversation_entity_import_global_key.json',
-])
+@pytest.mark.parametrize(
+    "filename",
+    [
+        "tests/data/assets/ndjson/classification_import_global_key.json",
+        "tests/data/assets/ndjson/metric_import_global_key.json",
+        "tests/data/assets/ndjson/polyline_import_global_key.json",
+        "tests/data/assets/ndjson/text_entity_import_global_key.json",
+        "tests/data/assets/ndjson/conversation_entity_import_global_key.json",
+    ],
+)
 def test_many_types(filename: str):
-    with open(filename, 'r') as f:
+    with open(filename, "r") as f:
         data = json.load(f)
         res = list(NDJsonConverter.deserialize(data))
         res = list(NDJsonConverter.serialize(res))
@@ -37,19 +40,20 @@ def test_image():
-    with open('tests/data/assets/ndjson/image_import_global_key.json',
-              'r') as f:
+    with open(
+        "tests/data/assets/ndjson/image_import_global_key.json", "r"
+    ) as f:
         data = json.load(f)
         res = list(NDJsonConverter.deserialize(data))
         res = list(NDJsonConverter.serialize(res))
         for r in res:
-            r.pop('classifications', None)
+            r.pop("classifications", None)
         assert [round_dict(x) for x in res] == [round_dict(x) for x in data]
         f.close()
 
 
 def test_pdf():
-    with open('tests/data/assets/ndjson/pdf_import_global_key.json', 'r') as f:
+    with open("tests/data/assets/ndjson/pdf_import_global_key.json", "r") as f:
         data = json.load(f)
         res = list(NDJsonConverter.deserialize(data))
         res = list(NDJsonConverter.serialize(res))
@@ -58,8 +62,9 @@ def test_video():
-    with open('tests/data/assets/ndjson/video_import_global_key.json',
-              'r') as f:
+    with open(
+        "tests/data/assets/ndjson/video_import_global_key.json", "r"
+    ) as f:
        data = 
json.load(f) res = list(NDJsonConverter.deserialize(data)) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_image.py b/libs/labelbox/tests/data/serialization/ndjson/test_image.py index e36ce6f50..1729e1f46 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_image.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_image.py @@ -3,7 +3,13 @@ import cv2 from labelbox.data.serialization.ndjson.converter import NDJsonConverter -from labelbox.data.annotation_types import Mask, Label, ObjectAnnotation, ImageData, MaskData +from labelbox.data.annotation_types import ( + Mask, + Label, + ObjectAnnotation, + ImageData, + MaskData, +) def round_dict(data): @@ -20,61 +26,56 @@ def round_dict(data): def test_image(): - with open('tests/data/assets/ndjson/image_import.json', 'r') as file: + with open("tests/data/assets/ndjson/image_import.json", "r") as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) for r in res: - r.pop('classifications', None) + r.pop("classifications", None) assert [round_dict(x) for x in res] == [round_dict(x) for x in data] def test_image_with_name_only(): - with open('tests/data/assets/ndjson/image_import_name_only.json', - 'r') as file: + with open( + "tests/data/assets/ndjson/image_import_name_only.json", "r" + ) as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) for r in res: - r.pop('classifications', None) + r.pop("classifications", None) assert [round_dict(x) for x in res] == [round_dict(x) for x in data] def test_mask(): - data = [{ - "uuid": "b862c586-8614-483c-b5e6-82810f70cac0", - "schemaId": "ckrazcueb16og0z6609jj7y3y", - "dataRow": { - "id": "ckrazctum0z8a0ybc0b0o0g0v" + data = [ + { + "uuid": "b862c586-8614-483c-b5e6-82810f70cac0", + "schemaId": "ckrazcueb16og0z6609jj7y3y", + "dataRow": {"id": "ckrazctum0z8a0ybc0b0o0g0v"}, + "mask": { + "png": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAAAAACoWZBhAAAAMklEQVR4nD3MuQ3AQADDMOqQ/Vd2ijytaSiZLAcYuyLEYYYl9cvrlGftTHvsYl+u/3EDv0QLI8Z7FlwAAAAASUVORK5CYII=" + }, + "confidence": 0.8, + "customMetrics": [{"name": "customMetric1", "value": 0.4}], }, - "mask": { - "png": - "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAAAAACoWZBhAAAAMklEQVR4nD3MuQ3AQADDMOqQ/Vd2ijytaSiZLAcYuyLEYYYl9cvrlGftTHvsYl+u/3EDv0QLI8Z7FlwAAAAASUVORK5CYII=" - }, - "confidence": 0.8, - "customMetrics": [{ - "name": "customMetric1", - "value": 0.4 - }], - }, { - "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6", - "schemaId": "ckrazcuec16ok0z66f956apb7", - "dataRow": { - "id": "ckrazctum0z8a0ybc0b0o0g0v" + { + "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6", + "schemaId": "ckrazcuec16ok0z66f956apb7", + "dataRow": {"id": "ckrazctum0z8a0ybc0b0o0g0v"}, + "mask": { + "instanceURI": "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY", + "colorRGB": [255, 0, 0], + }, }, - "mask": { - "instanceURI": - "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY", - "colorRGB": [255, 0, 0] - } - }] + ] res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) for r in res: - r.pop('classifications', None) + r.pop("classifications", None) assert [round_dict(x) for x in res] == [round_dict(x) for x in data] @@ -83,22 +84,24 @@ def 
test_mask_from_arr(): mask_arr = np.round(np.zeros((32, 32))).astype(np.uint8) mask_arr = cv2.rectangle(mask_arr, (5, 5), (10, 10), (1, 1), -1) - label = Label(annotations=[ - ObjectAnnotation(feature_schema_id="1" * 25, - value=Mask(mask=MaskData.from_2D_arr(arr=mask_arr), - color=(1, 1, 1))) - ], - data=ImageData(uid="0" * 25)) + label = Label( + annotations=[ + ObjectAnnotation( + feature_schema_id="1" * 25, + value=Mask( + mask=MaskData.from_2D_arr(arr=mask_arr), color=(1, 1, 1) + ), + ) + ], + data=ImageData(uid="0" * 25), + ) res = next(NDJsonConverter.serialize([label])) res.pop("uuid") assert res == { "classifications": [], "schemaId": "1" * 25, - "dataRow": { - "id": "0" * 25 - }, + "dataRow": {"id": "0" * 25}, "mask": { - "png": - "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAAAAABWESUoAAAAHklEQVR4nGNgGAKAEYn8j00BEyETBoOCUTAKhhwAAJW+AQwvpePVAAAAAElFTkSuQmCC" - } + "png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAAAAABWESUoAAAAHklEQVR4nGNgGAKAEYn8j00BEyETBoOCUTAKhhwAAJW+AQwvpePVAAAAAElFTkSuQmCC" + }, } diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_metric.py b/libs/labelbox/tests/data/serialization/ndjson/test_metric.py index 6508b73af..45c5c67bf 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_metric.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_metric.py @@ -4,7 +4,7 @@ def test_metric(): - with open('tests/data/assets/ndjson/metric_import.json', 'r') as file: + with open("tests/data/assets/ndjson/metric_import.json", "r") as file: data = json.load(file) label_list = list(NDJsonConverter.deserialize(data)) @@ -13,22 +13,26 @@ def test_metric(): def test_custom_scalar_metric(): - with open('tests/data/assets/ndjson/custom_scalar_import.json', - 'r') as file: + with open( + "tests/data/assets/ndjson/custom_scalar_import.json", "r" + ) as file: data = json.load(file) label_list = list(NDJsonConverter.deserialize(data)) reserialized = list(NDJsonConverter.serialize(label_list)) - assert json.dumps(reserialized, - sort_keys=True) == json.dumps(data, sort_keys=True) + assert json.dumps(reserialized, sort_keys=True) == json.dumps( + data, sort_keys=True + ) def test_custom_confusion_matrix_metric(): - with open('tests/data/assets/ndjson/custom_confusion_matrix_import.json', - 'r') as file: + with open( + "tests/data/assets/ndjson/custom_confusion_matrix_import.json", "r" + ) as file: data = json.load(file) label_list = list(NDJsonConverter.deserialize(data)) reserialized = list(NDJsonConverter.serialize(label_list)) - assert json.dumps(reserialized, - sort_keys=True) == json.dumps(data, sort_keys=True) + assert json.dumps(reserialized, sort_keys=True) == json.dumps( + data, sort_keys=True + ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py b/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py index bc093b79b..69594ff73 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py @@ -6,7 +6,7 @@ def test_message_task_annotation_serialization(): - with open('tests/data/assets/ndjson/mmc_import.json', 'r') as file: + with open("tests/data/assets/ndjson/mmc_import.json", "r") as file: data = json.load(file) deserialized = list(NDJsonConverter.deserialize(data)) @@ -16,14 +16,17 @@ def test_message_task_annotation_serialization(): def test_mesage_ranking_task_wrong_order_serialization(): - with open('tests/data/assets/ndjson/mmc_import.json', 'r') as file: + with open("tests/data/assets/ndjson/mmc_import.json", "r") as file: data = json.load(file) 
some_ranking_task = next(
-        task for task in data
-        if task["messageEvaluationTask"]["format"] == "message-ranking")
+        task
+        for task in data
+        if task["messageEvaluationTask"]["format"] == "message-ranking"
+    )
 
     some_ranking_task["messageEvaluationTask"]["data"]["rankedMessages"][0][
-        "order"] = 3
+        "order"
+    ] = 3
 
     with pytest.raises(ValueError):
         list(NDJsonConverter.deserialize([some_ranking_task]))
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py b/libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py
index 1f51c307a..790bd87b3 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py
@@ -5,13 +5,15 @@ def test_bad_annotation_input():
-    data = [{
-        "test": 3
-    }]
+    data = [{"test": 3}]
     with pytest.raises(ValueError):
         NDLabel(**{"annotations": data})
 
+
 def test_correct_annotation_input():
-    with open('tests/data/assets/ndjson/pdf_import_name_only.json', 'r') as f:
+    with open("tests/data/assets/ndjson/pdf_import_name_only.json", "r") as f:
         data = json.load(f)
-        assert isinstance(NDLabel(**{"annotations": [data[0]]}).annotations[0], NDDocumentRectangle)
+        assert isinstance(
+            NDLabel(**{"annotations": [data[0]]}).annotations[0],
+            NDDocumentRectangle,
+        )
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_nested.py b/libs/labelbox/tests/data/serialization/ndjson/test_nested.py
index 69fddf1ff..e0f0df0e6 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_nested.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_nested.py
@@ -4,7 +4,7 @@ def test_nested():
-    with open('tests/data/assets/ndjson/nested_import.json', 'r') as file:
+    with open("tests/data/assets/ndjson/nested_import.json", "r") as file:
         data = json.load(file)
         res = list(NDJsonConverter.deserialize(data))
         res = list(NDJsonConverter.serialize(res))
@@ -12,8 +12,9 @@ def test_nested_name_only():
-    with open('tests/data/assets/ndjson/nested_import_name_only.json',
-              'r') as file:
+    with open(
+        "tests/data/assets/ndjson/nested_import_name_only.json", "r"
+    ) as file:
         data = json.load(file)
         res = list(NDJsonConverter.deserialize(data))
         res = list(NDJsonConverter.serialize(res))
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py b/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py
index 933c378df..97d48a14e 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py
@@ -3,12 +3,15 @@ from labelbox.data.serialization.ndjson.converter import NDJsonConverter
 
 
-@pytest.mark.parametrize("filename", [
-    "tests/data/assets/ndjson/polyline_without_confidence_import.json",
-    "tests/data/assets/ndjson/polyline_import.json"
-])
+@pytest.mark.parametrize(
+    "filename",
+    [
+        "tests/data/assets/ndjson/polyline_without_confidence_import.json",
+        "tests/data/assets/ndjson/polyline_import.json",
+    ],
+)
 def test_polyline_import(filename: str):
-    with open(filename, 'r') as file:
+    with open(filename, "r") as file:
         data = json.load(file)
         res = list(NDJsonConverter.deserialize(data))
         res = list(NDJsonConverter.serialize(res))
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_radio.py b/libs/labelbox/tests/data/serialization/ndjson/test_radio.py
index 97cb073e0..bd80f9267 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_radio.py
+++ 
b/libs/labelbox/tests/data/serialization/ndjson/test_radio.py @@ -1,6 +1,8 @@ import json from labelbox.data.annotation_types.annotation import ClassificationAnnotation -from labelbox.data.annotation_types.classification.classification import ClassificationAnswer +from labelbox.data.annotation_types.classification.classification import ( + ClassificationAnswer, +) from labelbox.data.annotation_types.classification.classification import Radio from labelbox.data.annotation_types.data.text import TextData from labelbox.data.annotation_types.label import Label @@ -19,17 +21,18 @@ def test_serialization_with_radio_min(): ClassificationAnnotation( name="radio_question_geo", value=Radio( - answer=ClassificationAnswer(name="first_radio_answer",))) - ]) + answer=ClassificationAnswer( + name="first_radio_answer", + ) + ), + ) + ], + ) expected = { - 'name': 'radio_question_geo', - 'answer': { - 'name': 'first_radio_answer' - }, - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - } + "name": "radio_question_geo", + "answer": {"name": "first_radio_answer"}, + "dataRow": {"id": "bkj7z2q0b0000jx6x0q2q7q0d"}, } serialized = NDJsonConverter.serialize([label]) res = next(serialized) @@ -47,43 +50,51 @@ def test_serialization_with_radio_min(): def test_serialization_with_radio_classification(): - label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation( - name="radio_question_geo", - confidence=0.5, - value=Radio(answer=ClassificationAnswer( - confidence=0.6, - name="first_radio_answer", - classifications=[ - ClassificationAnnotation( - name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer"))) - ]))) - ]) + label = Label( + uid="ckj7z2q0b0000jx6x0q2q7q0d", + data=TextData( + uid="bkj7z2q0b0000jx6x0q2q7q0d", + text="This is a test", + ), + annotations=[ + ClassificationAnnotation( + name="radio_question_geo", + confidence=0.5, + value=Radio( + answer=ClassificationAnswer( + confidence=0.6, + name="first_radio_answer", + classifications=[ + ClassificationAnnotation( + name="sub_radio_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_radio_answer" + ) + ), + ) + ], + ) + ), + ) + ], + ) expected = { - 'confidence': 0.5, - 'name': 'radio_question_geo', - 'answer': { - 'confidence': - 0.6, - 'name': - 'first_radio_answer', - 'classifications': [{ - 'name': 'sub_radio_question', - 'answer': { - 'name': 'first_sub_radio_answer', + "confidence": 0.5, + "name": "radio_question_geo", + "answer": { + "confidence": 0.6, + "name": "first_radio_answer", + "classifications": [ + { + "name": "sub_radio_question", + "answer": { + "name": "first_sub_radio_answer", + }, } - }] + ], }, - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - } + "dataRow": {"id": "bkj7z2q0b0000jx6x0q2q7q0d"}, } serialized = NDJsonConverter.serialize([label]) @@ -94,5 +105,6 @@ def test_serialization_with_radio_classification(): deserialized = NDJsonConverter.deserialize([res]) res = next(deserialized) res.annotations[0].extra.pop("uuid") - assert res.annotations[0].model_dump(exclude_none=True) == label.annotations[0].model_dump(exclude_none=True) - + assert res.annotations[0].model_dump( + exclude_none=True + ) == label.annotations[0].model_dump(exclude_none=True) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py b/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py index c07dcc66d..66630dbb5 100644 --- 
a/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py @@ -6,7 +6,7 @@ def test_rectangle(): - with open('tests/data/assets/ndjson/rectangle_import.json', 'r') as file: + with open("tests/data/assets/ndjson/rectangle_import.json", "r") as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) @@ -14,7 +14,7 @@ def test_rectangle(): def test_rectangle_inverted_start_end_points(): - with open('tests/data/assets/ndjson/rectangle_import.json', 'r') as file: + with open("tests/data/assets/ndjson/rectangle_import.json", "r") as file: data = json.load(file) bbox = lb_types.ObjectAnnotation( @@ -23,10 +23,10 @@ def test_rectangle_inverted_start_end_points(): start=lb_types.Point(x=81, y=69), end=lb_types.Point(x=38, y=28), ), - extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}) + extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}, + ) - label = lb_types.Label(data={"uid":DATAROW_ID}, - annotations=[bbox]) + label = lb_types.Label(data={"uid": DATAROW_ID}, annotations=[bbox]) res = list(NDJsonConverter.serialize([label])) assert res == data @@ -40,18 +40,20 @@ def test_rectangle_inverted_start_end_points(): extra={ "uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72", "page": None, - "unit": None - }) + "unit": None, + }, + ) - label = lb_types.Label(data={"uid":DATAROW_ID}, - annotations=[expected_bbox]) + label = lb_types.Label( + data={"uid": DATAROW_ID}, annotations=[expected_bbox] + ) res = list(NDJsonConverter.deserialize(res)) assert res == [label] def test_rectangle_mixed_start_end_points(): - with open('tests/data/assets/ndjson/rectangle_import.json', 'r') as file: + with open("tests/data/assets/ndjson/rectangle_import.json", "r") as file: data = json.load(file) bbox = lb_types.ObjectAnnotation( @@ -60,10 +62,10 @@ def test_rectangle_mixed_start_end_points(): start=lb_types.Point(x=81, y=28), end=lb_types.Point(x=38, y=69), ), - extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}) + extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}, + ) - label = lb_types.Label(data={"uid":DATAROW_ID}, - annotations=[bbox]) + label = lb_types.Label(data={"uid": DATAROW_ID}, annotations=[bbox]) res = list(NDJsonConverter.serialize([label])) assert res == data @@ -77,11 +79,11 @@ def test_rectangle_mixed_start_end_points(): extra={ "uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72", "page": None, - "unit": None - }) + "unit": None, + }, + ) - label = lb_types.Label(data={"uid":DATAROW_ID}, - annotations=[bbox]) + label = lb_types.Label(data={"uid": DATAROW_ID}, annotations=[bbox]) res = list(NDJsonConverter.deserialize(res)) assert res == [label] @@ -94,13 +96,13 @@ def test_benchmark_reference_label_flag_enabled(): start=lb_types.Point(x=81, y=28), end=lb_types.Point(x=38, y=69), ), - extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"} + extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}, ) label = lb_types.Label( - data={"uid":DATAROW_ID}, + data={"uid": DATAROW_ID}, annotations=[bbox], - is_benchmark_reference=True + is_benchmark_reference=True, ) res = list(NDJsonConverter.serialize([label])) @@ -114,13 +116,13 @@ def test_benchmark_reference_label_flag_disabled(): start=lb_types.Point(x=81, y=28), end=lb_types.Point(x=38, y=69), ), - extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"} + extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}, ) label = lb_types.Label( - data={"uid":DATAROW_ID}, + data={"uid": DATAROW_ID}, annotations=[bbox], - 
is_benchmark_reference=False + is_benchmark_reference=False, ) res = list(NDJsonConverter.serialize([label])) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py b/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py index 9ede41d2c..f33719035 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py @@ -30,10 +30,14 @@ def test_relationship(): ] assert res_relationship_second_annotation - assert res_relationship_second_annotation["relationship"][ - "source"] != res_relationship_annotation["relationship"]["source"] - assert res_relationship_second_annotation["relationship"][ - "target"] != res_relationship_annotation["relationship"]["target"] + assert ( + res_relationship_second_annotation["relationship"]["source"] + != res_relationship_annotation["relationship"]["source"] + ) + assert ( + res_relationship_second_annotation["relationship"]["target"] + != res_relationship_annotation["relationship"]["target"] + ) assert res_relationship_second_annotation["relationship"]["source"] in [ annot["uuid"] for annot in res_source_and_target ] diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_text.py b/libs/labelbox/tests/data/serialization/ndjson/test_text.py index 534068e14..d5e81c51a 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_text.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_text.py @@ -1,5 +1,9 @@ from labelbox.data.annotation_types.annotation import ClassificationAnnotation -from labelbox.data.annotation_types.classification.classification import ClassificationAnswer, Radio, Text +from labelbox.data.annotation_types.classification.classification import ( + ClassificationAnswer, + Radio, + Text, +) from labelbox.data.annotation_types.data.text import TextData from labelbox.data.annotation_types.label import Label @@ -7,24 +11,29 @@ def test_serialization(): - label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation( - name="radio_question_geo", - confidence=0.5, - value=Text(answer="first_radio_answer")) - ]) + label = Label( + uid="ckj7z2q0b0000jx6x0q2q7q0d", + data=TextData( + uid="bkj7z2q0b0000jx6x0q2q7q0d", + text="This is a test", + ), + annotations=[ + ClassificationAnnotation( + name="radio_question_geo", + confidence=0.5, + value=Text(answer="first_radio_answer"), + ) + ], + ) serialized = NDJsonConverter.serialize([label]) res = next(serialized) - assert 'confidence' not in res # because confidence needs to be set on the annotation itself - assert res['name'] == "radio_question_geo" - assert res['answer'] == "first_radio_answer" - assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d" + assert ( + "confidence" not in res + ) # because confidence needs to be set on the annotation itself + assert res["name"] == "radio_question_geo" + assert res["answer"] == "first_radio_answer" + assert res["dataRow"]["id"] == "bkj7z2q0b0000jx6x0q2q7q0d" deserialized = NDJsonConverter.deserialize([res]) res = next(deserialized) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py b/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py index f62d87ebc..3e856f001 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py @@ -5,12 +5,15 @@ from 
labelbox.data.serialization.ndjson.converter import NDJsonConverter -@pytest.mark.parametrize("filename", [ - "tests/data/assets/ndjson/text_entity_import.json", - "tests/data/assets/ndjson/text_entity_without_confidence_import.json" -]) +@pytest.mark.parametrize( + "filename", + [ + "tests/data/assets/ndjson/text_entity_import.json", + "tests/data/assets/ndjson/text_entity_without_confidence_import.json", + ], +) def test_text_entity_import(filename: str): - with open(filename, 'r') as file: + with open(filename, "r") as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_video.py b/libs/labelbox/tests/data/serialization/ndjson/test_video.py index 4b90a8060..c7a6535c4 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_video.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_video.py @@ -1,6 +1,11 @@ import json from labelbox.client import Client -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnnotation, ClassificationAnswer, Radio +from labelbox.data.annotation_types.classification.classification import ( + Checklist, + ClassificationAnnotation, + ClassificationAnswer, + Radio, +) from labelbox.data.annotation_types.data.video import VideoData from labelbox.data.annotation_types.geometry.line import Line from labelbox.data.annotation_types.geometry.point import Point @@ -16,29 +21,31 @@ def test_video(): - with open('tests/data/assets/ndjson/video_import.json', 'r') as file: + with open("tests/data/assets/ndjson/video_import.json", "r") as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) - - data = sorted(data, key=itemgetter('uuid')) - res = sorted(res, key=itemgetter('uuid')) + + data = sorted(data, key=itemgetter("uuid")) + res = sorted(res, key=itemgetter("uuid")) pairs = zip(data, res) for data, res in pairs: assert data == res + def test_video_name_only(): - with open('tests/data/assets/ndjson/video_import_name_only.json', - 'r') as file: + with open( + "tests/data/assets/ndjson/video_import_name_only.json", "r" + ) as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) - - data = sorted(data, key=itemgetter('uuid')) - res = sorted(res, key=itemgetter('uuid')) + + data = sorted(data, key=itemgetter("uuid")) + res = sorted(res, key=itemgetter("uuid")) pairs = zip(data, res) for data, res in pairs: @@ -47,54 +54,60 @@ def test_video_name_only(): def test_video_classification_global_subclassifications(): label = Label( - data=VideoData(global_key="sample-video-4.mp4",), + data=VideoData( + global_key="sample-video-4.mp4", + ), annotations=[ ClassificationAnnotation( - name='radio_question_nested', - value=Radio(answer=ClassificationAnswer( - name='first_radio_question')), + name="radio_question_nested", + value=Radio( + answer=ClassificationAnswer(name="first_radio_question") + ), ), ClassificationAnnotation( - name='nested_checklist_question', + name="nested_checklist_question", value=Checklist( - name='checklist', + name="checklist", answer=[ ClassificationAnswer( - name='first_checklist_answer', + name="first_checklist_answer", classifications=[ ClassificationAnnotation( - name='sub_checklist_question', - value=Radio(answer=ClassificationAnswer( - name='first_sub_checklist_answer'))) - ]) - ])) - ]) + name="sub_checklist_question", + value=Radio( + answer=ClassificationAnswer( + 
name="first_sub_checklist_answer" + ) + ), + ) + ], + ) + ], + ), + ), + ], + ) expected_first_annotation = { - 'name': 'radio_question_nested', - 'answer': { - 'name': 'first_radio_question' - }, - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - } + "name": "radio_question_nested", + "answer": {"name": "first_radio_question"}, + "dataRow": {"globalKey": "sample-video-4.mp4"}, } expected_second_annotation = nested_checklist_annotation_ndjson = { "name": "nested_checklist_question", - "answer": [{ - "name": - "first_checklist_answer", - "classifications": [{ - "name": "sub_checklist_question", - "answer": { - "name": "first_sub_checklist_answer" - } - }] - }], - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - } + "answer": [ + { + "name": "first_checklist_answer", + "classifications": [ + { + "name": "sub_checklist_question", + "answer": {"name": "first_sub_checklist_answer"}, + } + ], + } + ], + "dataRow": {"globalKey": "sample-video-4.mp4"}, } serialized = NDJsonConverter.serialize([label]) @@ -123,18 +136,27 @@ def test_video_classification_nesting_bbox(): ), classifications=[ ClassificationAnnotation( - name='radio_question_nested', - value=Radio(answer=ClassificationAnswer( - name='first_radio_question', - classifications=[ - ClassificationAnnotation(name='sub_question_radio', - value=Checklist(answer=[ - ClassificationAnswer( - name='sub_answer') - ])) - ])), + name="radio_question_nested", + value=Radio( + answer=ClassificationAnswer( + name="first_radio_question", + classifications=[ + ClassificationAnnotation( + name="sub_question_radio", + value=Checklist( + answer=[ + ClassificationAnswer( + name="sub_answer" + ) + ] + ), + ) + ], + ) + ), ) - ]), + ], + ), VideoObjectAnnotation( name="bbox_video", keyframe=True, @@ -146,18 +168,27 @@ def test_video_classification_nesting_bbox(): ), classifications=[ ClassificationAnnotation( - name='nested_checklist_question', - value=Checklist(answer=[ - ClassificationAnswer( - name='first_checklist_answer', - classifications=[ - ClassificationAnnotation( - name='sub_checklist_question', - value=Radio(answer=ClassificationAnswer( - name='first_sub_checklist_answer'))) - ]) - ])) - ]), + name="nested_checklist_question", + value=Checklist( + answer=[ + ClassificationAnswer( + name="first_checklist_answer", + classifications=[ + ClassificationAnnotation( + name="sub_checklist_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_checklist_answer" + ) + ), + ) + ], + ) + ] + ), + ) + ], + ), VideoObjectAnnotation( name="bbox_video", keyframe=True, @@ -166,76 +197,91 @@ def test_video_classification_nesting_bbox(): value=Rectangle( start=Point(x=146.0, y=98.0), # Top left end=Point(x=382.0, y=341.0), # Bottom right - )) + ), + ), ] - expected = [{ - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - }, - 'name': - 'bbox_video', - 'classifications': [], - 'segments': [{ - 'keyframes': [{ - 'frame': - 13, - 'bbox': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - }, - 'classifications': [{ - 'name': 'radio_question_nested', - 'answer': { - 'name': - 'first_radio_question', - 'classifications': [{ - 'name': 'sub_question_radio', - 'answer': [{ - 'name': 'sub_answer' - }] - }] - } - }] - }, { - 'frame': - 15, - 'bbox': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - }, - 'classifications': [{ - 'name': - 'nested_checklist_question', - 'answer': [{ - 'name': - 'first_checklist_answer', - 'classifications': [{ - 'name': 'sub_checklist_question', - 'answer': { - 'name': 
'first_sub_checklist_answer' - } - }] - }] - }] - }, { - 'frame': 19, - 'bbox': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - }, - 'classifications': [] - }] - }] - }] - - label = Label(data=VideoData(global_key="sample-video-4.mp4",), - annotations=bbox_annotation) + expected = [ + { + "dataRow": {"globalKey": "sample-video-4.mp4"}, + "name": "bbox_video", + "classifications": [], + "segments": [ + { + "keyframes": [ + { + "frame": 13, + "bbox": { + "top": 98.0, + "left": 146.0, + "height": 243.0, + "width": 236.0, + }, + "classifications": [ + { + "name": "radio_question_nested", + "answer": { + "name": "first_radio_question", + "classifications": [ + { + "name": "sub_question_radio", + "answer": [ + {"name": "sub_answer"} + ], + } + ], + }, + } + ], + }, + { + "frame": 15, + "bbox": { + "top": 98.0, + "left": 146.0, + "height": 243.0, + "width": 236.0, + }, + "classifications": [ + { + "name": "nested_checklist_question", + "answer": [ + { + "name": "first_checklist_answer", + "classifications": [ + { + "name": "sub_checklist_question", + "answer": { + "name": "first_sub_checklist_answer" + }, + } + ], + } + ], + } + ], + }, + { + "frame": 19, + "bbox": { + "top": 98.0, + "left": 146.0, + "height": 243.0, + "width": 236.0, + }, + "classifications": [], + }, + ] + } + ], + } + ] + + label = Label( + data=VideoData( + global_key="sample-video-4.mp4", + ), + annotations=bbox_annotation, + ) serialized = NDJsonConverter.serialize([label]) res = [x for x in serialized] @@ -260,18 +306,27 @@ def test_video_classification_point(): value=Point(x=46.0, y=8.0), classifications=[ ClassificationAnnotation( - name='radio_question_nested', - value=Radio(answer=ClassificationAnswer( - name='first_radio_question', - classifications=[ - ClassificationAnnotation(name='sub_question_radio', - value=Checklist(answer=[ - ClassificationAnswer( - name='sub_answer') - ])) - ])), + name="radio_question_nested", + value=Radio( + answer=ClassificationAnswer( + name="first_radio_question", + classifications=[ + ClassificationAnnotation( + name="sub_question_radio", + value=Checklist( + answer=[ + ClassificationAnswer( + name="sub_answer" + ) + ] + ), + ) + ], + ) + ), ) - ]), + ], + ), VideoObjectAnnotation( name="bbox_video", keyframe=True, @@ -280,88 +335,111 @@ def test_video_classification_point(): value=Point(x=56.0, y=18.0), classifications=[ ClassificationAnnotation( - name='nested_checklist_question', - value=Checklist(answer=[ - ClassificationAnswer( - name='first_checklist_answer', - classifications=[ - ClassificationAnnotation( - name='sub_checklist_question', - value=Radio(answer=ClassificationAnswer( - name='first_sub_checklist_answer'))) - ]) - ])) - ]), + name="nested_checklist_question", + value=Checklist( + answer=[ + ClassificationAnswer( + name="first_checklist_answer", + classifications=[ + ClassificationAnnotation( + name="sub_checklist_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_checklist_answer" + ) + ), + ) + ], + ) + ] + ), + ) + ], + ), VideoObjectAnnotation( name="bbox_video", keyframe=True, frame=19, segment_index=0, value=Point(x=66.0, y=28.0), - ) + ), + ] + expected = [ + { + "dataRow": {"globalKey": "sample-video-4.mp4"}, + "name": "bbox_video", + "classifications": [], + "segments": [ + { + "keyframes": [ + { + "frame": 13, + "point": { + "x": 46.0, + "y": 8.0, + }, + "classifications": [ + { + "name": "radio_question_nested", + "answer": { + "name": "first_radio_question", + "classifications": [ + { + "name": 
"sub_question_radio", + "answer": [ + {"name": "sub_answer"} + ], + } + ], + }, + } + ], + }, + { + "frame": 15, + "point": { + "x": 56.0, + "y": 18.0, + }, + "classifications": [ + { + "name": "nested_checklist_question", + "answer": [ + { + "name": "first_checklist_answer", + "classifications": [ + { + "name": "sub_checklist_question", + "answer": { + "name": "first_sub_checklist_answer" + }, + } + ], + } + ], + } + ], + }, + { + "frame": 19, + "point": { + "x": 66.0, + "y": 28.0, + }, + "classifications": [], + }, + ] + } + ], + } ] - expected = [{ - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - }, - 'name': - 'bbox_video', - 'classifications': [], - 'segments': [{ - 'keyframes': [{ - 'frame': - 13, - 'point': { - 'x': 46.0, - 'y': 8.0, - }, - 'classifications': [{ - 'name': 'radio_question_nested', - 'answer': { - 'name': - 'first_radio_question', - 'classifications': [{ - 'name': 'sub_question_radio', - 'answer': [{ - 'name': 'sub_answer' - }] - }] - } - }] - }, { - 'frame': - 15, - 'point': { - 'x': 56.0, - 'y': 18.0, - }, - 'classifications': [{ - 'name': - 'nested_checklist_question', - 'answer': [{ - 'name': - 'first_checklist_answer', - 'classifications': [{ - 'name': 'sub_checklist_question', - 'answer': { - 'name': 'first_sub_checklist_answer' - } - }] - }] - }] - }, { - 'frame': 19, - 'point': { - 'x': 66.0, - 'y': 28.0, - }, - 'classifications': [] - }] - }] - }] - - label = Label(data=VideoData(global_key="sample-video-4.mp4",), - annotations=bbox_annotation) + + label = Label( + data=VideoData( + global_key="sample-video-4.mp4", + ), + annotations=bbox_annotation, + ) serialized = NDJsonConverter.serialize([label]) res = [x for x in serialized] @@ -382,123 +460,161 @@ def test_video_classification_frameline(): keyframe=True, frame=13, segment_index=0, - value=Line( - points=[Point(x=8, y=10), Point(x=10, y=9)]), + value=Line(points=[Point(x=8, y=10), Point(x=10, y=9)]), classifications=[ ClassificationAnnotation( - name='radio_question_nested', - value=Radio(answer=ClassificationAnswer( - name='first_radio_question', - classifications=[ - ClassificationAnnotation(name='sub_question_radio', - value=Checklist(answer=[ - ClassificationAnswer( - name='sub_answer') - ])) - ])), + name="radio_question_nested", + value=Radio( + answer=ClassificationAnswer( + name="first_radio_question", + classifications=[ + ClassificationAnnotation( + name="sub_question_radio", + value=Checklist( + answer=[ + ClassificationAnswer( + name="sub_answer" + ) + ] + ), + ) + ], + ) + ), ) - ]), + ], + ), VideoObjectAnnotation( name="bbox_video", keyframe=True, frame=15, segment_index=0, - value=Line( - points=[Point(x=18, y=20), Point(x=20, y=19)]), + value=Line(points=[Point(x=18, y=20), Point(x=20, y=19)]), classifications=[ ClassificationAnnotation( - name='nested_checklist_question', - value=Checklist(answer=[ - ClassificationAnswer( - name='first_checklist_answer', - classifications=[ - ClassificationAnnotation( - name='sub_checklist_question', - value=Radio(answer=ClassificationAnswer( - name='first_sub_checklist_answer'))) - ]) - ])) - ]), + name="nested_checklist_question", + value=Checklist( + answer=[ + ClassificationAnswer( + name="first_checklist_answer", + classifications=[ + ClassificationAnnotation( + name="sub_checklist_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_checklist_answer" + ) + ), + ) + ], + ) + ] + ), + ) + ], + ), VideoObjectAnnotation( name="bbox_video", keyframe=True, frame=19, segment_index=0, - value=Line( - points=[Point(x=28, y=30), 
Point(x=30, y=29)]), - ) + value=Line(points=[Point(x=28, y=30), Point(x=30, y=29)]), + ), ] - expected = [{ - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - }, - 'name': - 'bbox_video', - 'classifications': [], - 'segments': [{ - 'keyframes': [{ - 'frame': - 13, - 'line': [{ - 'x': 8.0, - 'y': 10.0, - }, { - 'x': 10.0, - 'y': 9.0, - }], - 'classifications': [{ - 'name': 'radio_question_nested', - 'answer': { - 'name': - 'first_radio_question', - 'classifications': [{ - 'name': 'sub_question_radio', - 'answer': [{ - 'name': 'sub_answer' - }] - }] - } - }] - }, { - 'frame': - 15, - 'line': [{ - 'x': 18.0, - 'y': 20.0, - }, { - 'x': 20.0, - 'y': 19.0, - }], - 'classifications': [{ - 'name': - 'nested_checklist_question', - 'answer': [{ - 'name': - 'first_checklist_answer', - 'classifications': [{ - 'name': 'sub_checklist_question', - 'answer': { - 'name': 'first_sub_checklist_answer' - } - }] - }] - }] - }, { - 'frame': 19, - 'line': [{ - 'x': 28.0, - 'y': 30.0, - }, { - 'x': 30.0, - 'y': 29.0, - }], - 'classifications': [] - }] - }] - }] - - label = Label(data=VideoData(global_key="sample-video-4.mp4",), - annotations=bbox_annotation) + expected = [ + { + "dataRow": {"globalKey": "sample-video-4.mp4"}, + "name": "bbox_video", + "classifications": [], + "segments": [ + { + "keyframes": [ + { + "frame": 13, + "line": [ + { + "x": 8.0, + "y": 10.0, + }, + { + "x": 10.0, + "y": 9.0, + }, + ], + "classifications": [ + { + "name": "radio_question_nested", + "answer": { + "name": "first_radio_question", + "classifications": [ + { + "name": "sub_question_radio", + "answer": [ + {"name": "sub_answer"} + ], + } + ], + }, + } + ], + }, + { + "frame": 15, + "line": [ + { + "x": 18.0, + "y": 20.0, + }, + { + "x": 20.0, + "y": 19.0, + }, + ], + "classifications": [ + { + "name": "nested_checklist_question", + "answer": [ + { + "name": "first_checklist_answer", + "classifications": [ + { + "name": "sub_checklist_question", + "answer": { + "name": "first_sub_checklist_answer" + }, + } + ], + } + ], + } + ], + }, + { + "frame": 19, + "line": [ + { + "x": 28.0, + "y": 30.0, + }, + { + "x": 30.0, + "y": 29.0, + }, + ], + "classifications": [], + }, + ] + } + ], + } + ] + + label = Label( + data=VideoData( + global_key="sample-video-4.mp4", + ), + annotations=bbox_annotation, + ) serialized = NDJsonConverter.serialize([label]) res = [x for x in serialized] assert res == expected diff --git a/libs/labelbox/tests/data/test_data_row_metadata.py b/libs/labelbox/tests/data/test_data_row_metadata.py index 1cadc4376..9a3690776 100644 --- a/libs/labelbox/tests/data/test_data_row_metadata.py +++ b/libs/labelbox/tests/data/test_data_row_metadata.py @@ -6,7 +6,13 @@ from labelbox import Dataset from labelbox.exceptions import MalformedQueryException from labelbox.schema.identifiables import GlobalKeys, UniqueIds -from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadata, DataRowMetadataKind, DataRowMetadataOntology, _parse_metadata_schema +from labelbox.schema.data_row_metadata import ( + DataRowMetadataField, + DataRowMetadata, + DataRowMetadataKind, + DataRowMetadataOntology, + _parse_metadata_schema, +) INVALID_SCHEMA_ID = "1" * 25 FAKE_SCHEMA_ID = "0" * 25 @@ -16,13 +22,13 @@ TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt" TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh" CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb" -CUSTOM_TEXT_SCHEMA_NAME = 'custom_text' +CUSTOM_TEXT_SCHEMA_NAME = "custom_text" FAKE_NUMBER_FIELD = { "id": FAKE_SCHEMA_ID, "name": "number", - "kind": 'CustomMetadataNumber', 
- "reserved": False + "kind": "CustomMetadataNumber", + "reserved": False, } @@ -42,12 +48,12 @@ def mdo(client): @pytest.fixture def big_dataset(dataset: Dataset, image_url): - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image" - }, - ] * 5) + task = dataset.create_data_rows( + [ + {"row_data": image_url, "external_id": "my-image"}, + ] + * 5 + ) task.wait_till_done() yield dataset @@ -61,11 +67,13 @@ def make_metadata(dr_id: str = None, gk: str = None) -> DataRowMetadata: global_key=gk, data_row_id=dr_id, fields=[ - DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID, - value=TEST_SPLIT_ID), + DataRowMetadataField( + schema_id=SPLIT_SCHEMA_ID, value=TEST_SPLIT_ID + ), DataRowMetadataField(schema_id=CAPTURE_DT_SCHEMA_ID, value=time), DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value=msg), - ]) + ], + ) return metadata @@ -73,29 +81,29 @@ def make_named_metadata(dr_id) -> DataRowMetadata: msg = "A message" time = datetime.utcnow() - metadata = DataRowMetadata(data_row_id=dr_id, - fields=[ - DataRowMetadataField(name='split', - value=TEST_SPLIT_ID), - DataRowMetadataField(name='captureDateTime', - value=time), - DataRowMetadataField( - name=CUSTOM_TEXT_SCHEMA_NAME, value=msg), - ]) + metadata = DataRowMetadata( + data_row_id=dr_id, + fields=[ + DataRowMetadataField(name="split", value=TEST_SPLIT_ID), + DataRowMetadataField(name="captureDateTime", value=time), + DataRowMetadataField(name=CUSTOM_TEXT_SCHEMA_NAME, value=msg), + ], + ) return metadata @pytest.mark.skip(reason="broken export v1 api, to be retired soon") -def test_export_empty_metadata(client, configured_project_with_label, - wait_for_data_row_processing): +def test_export_empty_metadata( + client, configured_project_with_label, wait_for_data_row_processing +): project, _, data_row, _ = configured_project_with_label data_row = wait_for_data_row_processing(client, data_row) - + export_task = project.export(params={"metadata_fields": True}) export_task.wait_till_done() stream = export_task.get_buffered_stream() data_row = [data_row.json for data_row in stream][0] - + assert data_row["metadata_fields"] == [] @@ -134,9 +142,11 @@ def test_get_datarow_metadata_ontology(mdo): value=datetime.utcnow(), ), DataRowMetadataField(schema_id=split.parent, value=split.uid), - DataRowMetadataField(schema_id=mdo.reserved_by_name["tag"].uid, - value="hello-world"), - ]) + DataRowMetadataField( + schema_id=mdo.reserved_by_name["tag"].uid, value="hello-world" + ), + ], + ) def test_bulk_upsert_datarow_metadata(data_row, mdo: DataRowMetadataOntology): @@ -148,7 +158,8 @@ def test_bulk_upsert_datarow_metadata(data_row, mdo: DataRowMetadataOntology): def test_bulk_upsert_datarow_metadata_by_globalkey( - data_rows, mdo: DataRowMetadataOntology): + data_rows, mdo: DataRowMetadataOntology +): global_keys = [data_row.global_key for data_row in data_rows] metadata = [make_metadata(gk=global_key) for global_key in global_keys] errors = mdo.bulk_upsert(metadata) @@ -169,8 +180,9 @@ def test_large_bulk_upsert_datarow_metadata(big_dataset, mdo): for metadata in mdo.bulk_export(data_row_ids) } for data_row_id in data_row_ids: - assert len([f for f in metadata_lookup.get(data_row_id).fields - ]), metadata_lookup.get(data_row_id).fields + assert len( + [f for f in metadata_lookup.get(data_row_id).fields] + ), metadata_lookup.get(data_row_id).fields def test_upsert_datarow_metadata_by_name(data_row, mdo): @@ -182,16 +194,18 @@ def test_upsert_datarow_metadata_by_name(data_row, mdo): metadata.data_row_id: metadata for metadata 
in mdo.bulk_export([data_row.uid]) } - assert len([f for f in metadata_lookup.get(data_row.uid).fields - ]), metadata_lookup.get(data_row.uid).fields + assert len( + [f for f in metadata_lookup.get(data_row.uid).fields] + ), metadata_lookup.get(data_row.uid).fields def test_upsert_datarow_metadata_option_by_name(data_row, mdo): - metadata = DataRowMetadata(data_row_id=data_row.uid, - fields=[ - DataRowMetadataField(name='split', - value='test'), - ]) + metadata = DataRowMetadata( + data_row_id=data_row.uid, + fields=[ + DataRowMetadataField(name="split", value="test"), + ], + ) errors = mdo.bulk_upsert([metadata]) assert len(errors) == 0 @@ -199,16 +213,17 @@ def test_upsert_datarow_metadata_option_by_name(data_row, mdo): assert len(datarows[0].fields) == 1 metadata = datarows[0].fields[0] assert metadata.schema_id == SPLIT_SCHEMA_ID - assert metadata.name == 'test' + assert metadata.name == "test" assert metadata.value == TEST_SPLIT_ID def test_upsert_datarow_metadata_option_by_incorrect_name(data_row, mdo): - metadata = DataRowMetadata(data_row_id=data_row.uid, - fields=[ - DataRowMetadataField(name='split', - value='test1'), - ]) + metadata = DataRowMetadata( + data_row_id=data_row.uid, + fields=[ + DataRowMetadataField(name="split", value="test1"), + ], + ) with pytest.raises(KeyError): mdo.bulk_upsert([metadata]) @@ -216,55 +231,47 @@ def test_upsert_datarow_metadata_option_by_incorrect_name(data_row, mdo): def test_raise_enum_upsert_schema_error(data_row, mdo): """Setting an option id as the schema id will raise a Value Error""" - metadata = DataRowMetadata(data_row_id=data_row.uid, - fields=[ - DataRowMetadataField(schema_id=TEST_SPLIT_ID, - value=SPLIT_SCHEMA_ID), - ]) + metadata = DataRowMetadata( + data_row_id=data_row.uid, + fields=[ + DataRowMetadataField( + schema_id=TEST_SPLIT_ID, value=SPLIT_SCHEMA_ID + ), + ], + ) with pytest.raises(ValueError): mdo.bulk_upsert([metadata]) def test_upsert_non_existent_schema_id(data_row, mdo): """Raise error on non-existent schema id""" - metadata = DataRowMetadata(data_row_id=data_row.uid, - fields=[ - DataRowMetadataField( - schema_id=INVALID_SCHEMA_ID, - value="message"), - ]) + metadata = DataRowMetadata( + data_row_id=data_row.uid, + fields=[ + DataRowMetadataField(schema_id=INVALID_SCHEMA_ID, value="message"), + ], + ) with pytest.raises(ValueError): mdo.bulk_upsert([metadata]) def test_parse_raw_metadata(mdo): example = { - 'dataRowId': - 'ckr6kkfx801ui0yrtg9fje8xh', - 'globalKey': - 'global-key-1', - 'fields': [ - { - 'schemaId': 'cko8s9r5v0001h2dk9elqdidh', - 'value': 'my-new-message' - }, - { - 'schemaId': 'cko8sbczn0002h2dkdaxb5kal', - 'value': {} - }, + "dataRowId": "ckr6kkfx801ui0yrtg9fje8xh", + "globalKey": "global-key-1", + "fields": [ { - 'schemaId': 'cko8sbscr0003h2dk04w86hof', - 'value': {} + "schemaId": "cko8s9r5v0001h2dk9elqdidh", + "value": "my-new-message", }, + {"schemaId": "cko8sbczn0002h2dkdaxb5kal", "value": {}}, + {"schemaId": "cko8sbscr0003h2dk04w86hof", "value": {}}, { - 'schemaId': 'cko8sdzv70006h2dk8jg64zvb', - 'value': '2021-07-20T21:41:14.606710Z' + "schemaId": "cko8sdzv70006h2dk8jg64zvb", + "value": "2021-07-20T21:41:14.606710Z", }, - { - 'schemaId': FAKE_SCHEMA_ID, - 'value': 0.5 - }, - ] + {"schemaId": FAKE_SCHEMA_ID, "value": 0.5}, + ], } parsed = mdo.parse_metadata([example]) @@ -281,26 +288,14 @@ def test_parse_raw_metadata(mdo): def test_parse_raw_metadata_fields(mdo): example = [ + {"schemaId": "cko8s9r5v0001h2dk9elqdidh", "value": "my-new-message"}, + {"schemaId": "cko8sbczn0002h2dkdaxb5kal", 
"value": {}}, + {"schemaId": "cko8sbscr0003h2dk04w86hof", "value": {}}, { - 'schemaId': 'cko8s9r5v0001h2dk9elqdidh', - 'value': 'my-new-message' - }, - { - 'schemaId': 'cko8sbczn0002h2dkdaxb5kal', - 'value': {} - }, - { - 'schemaId': 'cko8sbscr0003h2dk04w86hof', - 'value': {} - }, - { - 'schemaId': 'cko8sdzv70006h2dk8jg64zvb', - 'value': '2021-07-20T21:41:14.606710Z' - }, - { - 'schemaId': FAKE_SCHEMA_ID, - 'value': 0.5 + "schemaId": "cko8sdzv70006h2dk8jg64zvb", + "value": "2021-07-20T21:41:14.606710Z", }, + {"schemaId": FAKE_SCHEMA_ID, "value": 0.5}, ] parsed = mdo.parse_metadata_fields(example) @@ -312,35 +307,36 @@ def test_parse_raw_metadata_fields(mdo): def test_parse_metadata_schema(): unparsed = { - 'id': - 'cl467a4ec0046076g7s9yheoa', - 'name': - 'enum metadata', - 'kind': - 'CustomMetadataEnum', - 'options': [{ - 'id': 'cl467a4ec0047076ggjneeruy', - 'name': 'option1', - 'kind': 'CustomMetadataEnumOption' - }, { - 'id': 'cl4qa31u0009e078p5m280jer', - 'name': 'option2', - 'kind': 'CustomMetadataEnumOption' - }] + "id": "cl467a4ec0046076g7s9yheoa", + "name": "enum metadata", + "kind": "CustomMetadataEnum", + "options": [ + { + "id": "cl467a4ec0047076ggjneeruy", + "name": "option1", + "kind": "CustomMetadataEnumOption", + }, + { + "id": "cl4qa31u0009e078p5m280jer", + "name": "option2", + "kind": "CustomMetadataEnumOption", + }, + ], } parsed = _parse_metadata_schema(unparsed) - assert parsed.uid == 'cl467a4ec0046076g7s9yheoa' - assert parsed.name == 'enum metadata' + assert parsed.uid == "cl467a4ec0046076g7s9yheoa" + assert parsed.name == "enum metadata" assert parsed.kind == DataRowMetadataKind.enum assert len(parsed.options) == 2 - assert parsed.options[0].uid == 'cl467a4ec0047076ggjneeruy' + assert parsed.options[0].uid == "cl467a4ec0047076ggjneeruy" assert parsed.options[0].kind == DataRowMetadataKind.option def test_create_schema(mdo): metadata_name = str(uuid.uuid4()) - created_schema = mdo.create_schema(metadata_name, DataRowMetadataKind.enum, - ["option 1", "option 2"]) + created_schema = mdo.create_schema( + metadata_name, DataRowMetadataKind.enum, ["option 1", "option 2"] + ) assert created_schema.name == metadata_name assert created_schema.kind == DataRowMetadataKind.enum assert len(created_schema.options) == 2 @@ -350,10 +346,12 @@ def test_create_schema(mdo): def test_update_schema(mdo): metadata_name = str(uuid.uuid4()) - created_schema = mdo.create_schema(metadata_name, DataRowMetadataKind.enum, - ["option 1", "option 2"]) - updated_schema = mdo.update_schema(metadata_name, - f"{metadata_name}_updated") + created_schema = mdo.create_schema( + metadata_name, DataRowMetadataKind.enum, ["option 1", "option 2"] + ) + updated_schema = mdo.update_schema( + metadata_name, f"{metadata_name}_updated" + ) assert updated_schema.name == f"{metadata_name}_updated" assert updated_schema.uid == created_schema.uid assert updated_schema.kind == DataRowMetadataKind.enum @@ -362,10 +360,12 @@ def test_update_schema(mdo): def test_update_enum_options(mdo): metadata_name = str(uuid.uuid4()) - created_schema = mdo.create_schema(metadata_name, DataRowMetadataKind.enum, - ["option 1", "option 2"]) - updated_schema = mdo.update_enum_option(metadata_name, "option 1", - "option 3") + created_schema = mdo.create_schema( + metadata_name, DataRowMetadataKind.enum, ["option 1", "option 2"] + ) + updated_schema = mdo.update_enum_option( + metadata_name, "option 1", "option 3" + ) assert updated_schema.name == metadata_name assert updated_schema.uid == created_schema.uid assert updated_schema.kind == 
DataRowMetadataKind.enum @@ -376,23 +376,28 @@ def test_update_enum_options(mdo): def test_delete_schema(mdo): metadata_name = str(uuid.uuid4()) - created_schema = mdo.create_schema(metadata_name, - DataRowMetadataKind.string) + created_schema = mdo.create_schema( + metadata_name, DataRowMetadataKind.string + ) status = mdo.delete_schema(created_schema.name) mdo.refresh_ontology() assert status assert metadata_name not in mdo.custom_by_name -@pytest.mark.parametrize('datetime_str', - ['2011-11-04T00:05:23Z', '2011-11-04T00:05:23+00:00']) +@pytest.mark.parametrize( + "datetime_str", ["2011-11-04T00:05:23Z", "2011-11-04T00:05:23+00:00"] +) def test_upsert_datarow_date_metadata(data_row, mdo, datetime_str): metadata = [ - DataRowMetadata(data_row_id=data_row.uid, - fields=[ - DataRowMetadataField(name='captureDateTime', - value=datetime_str), - ]) + DataRowMetadata( + data_row_id=data_row.uid, + fields=[ + DataRowMetadataField( + name="captureDateTime", value=datetime_str + ), + ], + ) ] errors = mdo.bulk_upsert(metadata) assert len(errors) == 0 @@ -401,18 +406,22 @@ def test_upsert_datarow_date_metadata(data_row, mdo, datetime_str): assert f"{metadata[0].fields[0].value}" == "2011-11-04 00:05:23+00:00" -@pytest.mark.parametrize('datetime_str', - ['2011-11-04T00:05:23Z', '2011-11-04T00:05:23+00:00']) +@pytest.mark.parametrize( + "datetime_str", ["2011-11-04T00:05:23Z", "2011-11-04T00:05:23+00:00"] +) def test_create_data_row_with_metadata(dataset, image_url, datetime_str): client = dataset.client assert len(list(dataset.data_rows())) == 0 metadata_fields = [ - DataRowMetadataField(name='captureDateTime', value=datetime_str) + DataRowMetadataField(name="captureDateTime", value=datetime_str) ] - data_row = dataset.create_data_row(row_data=image_url, - metadata_fields=metadata_fields) + data_row = dataset.create_data_row( + row_data=image_url, metadata_fields=metadata_fields + ) retrieved_data_row = client.get_data_row(data_row.uid) - assert f"{retrieved_data_row.metadata[0].value}" == "2011-11-04 00:05:23+00:00" + assert ( + f"{retrieved_data_row.metadata[0].value}" == "2011-11-04 00:05:23+00:00" + ) diff --git a/libs/labelbox/tests/data/test_prefetch_generator.py b/libs/labelbox/tests/data/test_prefetch_generator.py index 2738f3640..b90074a9d 100644 --- a/libs/labelbox/tests/data/test_prefetch_generator.py +++ b/libs/labelbox/tests/data/test_prefetch_generator.py @@ -4,13 +4,12 @@ class ChildClassGenerator(PrefetchGenerator): - def __init__(self, examples, num_executors=1): super().__init__(data=examples, num_executors=num_executors) def _process(self, value): num = random() - if num < .2: + if num < 0.2: raise ValueError("Randomized value error") return value diff --git a/libs/labelbox/tests/integration/conftest.py b/libs/labelbox/tests/integration/conftest.py index 5b1f9aa9a..d37287fe8 100644 --- a/libs/labelbox/tests/integration/conftest.py +++ b/libs/labelbox/tests/integration/conftest.py @@ -17,7 +17,15 @@ from labelbox import Dataset, DataRow from labelbox import LabelingFrontend -from labelbox import OntologyBuilder, Tool, Option, Classification, MediaType, PromptResponseClassification, ResponseOption +from labelbox import ( + OntologyBuilder, + Tool, + Option, + Classification, + MediaType, + PromptResponseClassification, + ResponseOption, +) from labelbox.orm import query from labelbox.pagination import PaginatedCollection from labelbox.schema.annotation_import import LabelImport @@ -46,9 +54,10 @@ def project_based_user(client, rand_gen): newUserId } } - """ % (email, 
str(client.get_roles()['NONE'].uid)) - user_id = client.execute( - query_str)['addMembersToOrganization'][0]['newUserId'] + """ % (email, str(client.get_roles()["NONE"].uid)) + user_id = client.execute(query_str)["addMembersToOrganization"][0][ + "newUserId" + ] assert user_id is not None, "Unable to add user with old mutation" user = client._get_single(User, user_id) yield user @@ -58,9 +67,12 @@ def project_based_user(client, rand_gen): @pytest.fixture def project_pack(client): projects = [ - client.create_project(name=f"user-proj-{idx}", - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) for idx in range(2) + client.create_project( + name=f"user-proj-{idx}", + queue_mode=QueueMode.Batch, + media_type=MediaType.Image, + ) + for idx in range(2) ] yield projects for proj in projects: @@ -71,15 +83,18 @@ def project_pack(client): def project_with_empty_ontology(project): editor = list( project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] + where=LabelingFrontend.name == "editor" + ) + )[0] empty_ontology = {"tools": [], "classifications": []} project.setup(editor, empty_ontology) yield project @pytest.fixture -def configured_project(project_with_empty_ontology, initial_dataset, rand_gen, - image_url): +def configured_project( + project_with_empty_ontology, initial_dataset, rand_gen, image_url +): dataset = initial_dataset data_row_id = dataset.create_data_row(row_data=image_url).uid project = project_with_empty_ontology @@ -87,7 +102,7 @@ def configured_project(project_with_empty_ontology, initial_dataset, rand_gen, batch = project.create_batch( rand_gen(str), [data_row_id], # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) project.data_row_ids = [data_row_id] @@ -97,11 +112,14 @@ def configured_project(project_with_empty_ontology, initial_dataset, rand_gen, @pytest.fixture -def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, - image_url): - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) +def configured_project_with_complex_ontology( + client, initial_dataset, rand_gen, image_url +): + project = client.create_project( + name=rand_gen(str), + queue_mode=QueueMode.Batch, + media_type=MediaType.Image, + ) dataset = initial_dataset data_row = dataset.create_data_row(row_data=image_url) data_row_ids = [data_row.uid] @@ -109,13 +127,15 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, project.create_batch( rand_gen(str), data_row_ids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) project.data_row_ids = data_row_ids editor = list( project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] + where=LabelingFrontend.name == "editor" + ) + )[0] ontology = OntologyBuilder() tools = [ @@ -123,24 +143,29 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, Tool(tool=Tool.Type.LINE, name="test-line-class"), Tool(tool=Tool.Type.POINT, name="test-point-class"), Tool(tool=Tool.Type.POLYGON, name="test-polygon-class"), - Tool(tool=Tool.Type.NER, name="test-ner-class") + Tool(tool=Tool.Type.NER, name="test-ner-class"), ] options = [ Option(value="first option answer"), Option(value="second option answer"), - Option(value="third option answer") + Option(value="third option answer"), ] classifications = [ - 
Classification(class_type=Classification.Type.TEXT, - name="test-text-class"), - Classification(class_type=Classification.Type.RADIO, - name="test-radio-class", - options=options), - Classification(class_type=Classification.Type.CHECKLIST, - name="test-checklist-class", - options=options) + Classification( + class_type=Classification.Type.TEXT, name="test-text-class" + ), + Classification( + class_type=Classification.Type.RADIO, + name="test-radio-class", + options=options, + ), + Classification( + class_type=Classification.Type.CHECKLIST, + name="test-checklist-class", + options=options, + ), ] for t in tools: @@ -161,19 +186,22 @@ def ontology(client): ontology_builder = OntologyBuilder( tools=[ Tool(tool=Tool.Type.BBOX, name="Box 1", color="#ff0000"), - Tool(tool=Tool.Type.BBOX, name="Box 2", color="#ff0000") + Tool(tool=Tool.Type.BBOX, name="Box 2", color="#ff0000"), ], classifications=[ - Classification(name="Root Class", - class_type=Classification.Type.RADIO, - options=[ - Option(value="1", label="Option 1"), - Option(value="2", label="Option 2") - ]) - ]) - ontology = client.create_ontology('Integration Test Ontology', - ontology_builder.asdict(), - MediaType.Image) + Classification( + name="Root Class", + class_type=Classification.Type.RADIO, + options=[ + Option(value="1", label="Option 1"), + Option(value="2", label="Option 2"), + ], + ) + ], + ) + ontology = client.create_ontology( + "Integration Test Ontology", ontology_builder.asdict(), MediaType.Image + ) yield ontology client.delete_unused_ontology(ontology.uid) @@ -191,12 +219,9 @@ def video_data(client, rand_gen, video_data_row, wait_for_data_row_processing): def create_video_data_row(rand_gen): return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{rand_gen(str)}", - "media_type": - "VIDEO", + "row_data": "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", + "global_key": f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{rand_gen(str)}", + "media_type": "VIDEO", } @@ -218,25 +243,25 @@ def video_data_row(rand_gen): class ExportV2Helpers: - @classmethod - def run_project_export_v2_task(cls, - project, - num_retries=5, - task_name=None, - filters={}, - params={}): + def run_project_export_v2_task( + cls, project, num_retries=5, task_name=None, filters={}, params={} + ): task = None - params = params if params else { - "project_details": True, - "performance_details": False, - "data_row_details": True, - "label_details": True - } - while (num_retries > 0): - task = project.export_v2(task_name=task_name, - filters=filters, - params=params) + params = ( + params + if params + else { + "project_details": True, + "performance_details": False, + "data_row_details": True, + "label_details": True, + } + ) + while num_retries > 0: + task = project.export_v2( + task_name=task_name, filters=filters, params=params + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None @@ -248,21 +273,19 @@ def run_project_export_v2_task(cls, return task.result @classmethod - def run_dataset_export_v2_task(cls, - dataset, - num_retries=5, - task_name=None, - filters={}, - params={}): + def run_dataset_export_v2_task( + cls, dataset, num_retries=5, task_name=None, filters={}, params={} + ): task = None - params = params if params else { - "performance_details": False, - 
"label_details": True - } - while (num_retries > 0): - task = dataset.export_v2(task_name=task_name, - filters=filters, - params=params) + params = ( + params + if params + else {"performance_details": False, "label_details": True} + ) + while num_retries > 0: + task = dataset.export_v2( + task_name=task_name, filters=filters, params=params + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None @@ -275,23 +298,20 @@ def run_dataset_export_v2_task(cls, return task.result @classmethod - def run_catalog_export_v2_task(cls, - client, - num_retries=5, - task_name=None, - filters={}, - params={}): + def run_catalog_export_v2_task( + cls, client, num_retries=5, task_name=None, filters={}, params={} + ): task = None - params = params if params else { - "performance_details": False, - "label_details": True - } + params = ( + params + if params + else {"performance_details": False, "label_details": True} + ) catalog = client.get_catalog() - while (num_retries > 0): - - task = catalog.export_v2(task_name=task_name, - filters=filters, - params=params) + while num_retries > 0: + task = catalog.export_v2( + task_name=task_name, filters=filters, params=params + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None @@ -317,9 +337,10 @@ def big_dataset_data_row_ids(big_dataset: Dataset): yield [dr.json["data_row"]["id"] for dr in stream] -@pytest.fixture(scope='function') -def dataset_with_invalid_data_rows(unique_dataset: Dataset, - upload_invalid_data_rows_for_dataset): +@pytest.fixture(scope="function") +def dataset_with_invalid_data_rows( + unique_dataset: Dataset, upload_invalid_data_rows_for_dataset +): upload_invalid_data_rows_for_dataset(unique_dataset) yield unique_dataset @@ -327,29 +348,33 @@ def dataset_with_invalid_data_rows(unique_dataset: Dataset, @pytest.fixture def upload_invalid_data_rows_for_dataset(): - def _upload_invalid_data_rows_for_dataset(dataset: Dataset): - task = dataset.create_data_rows([ - { - "row_data": 'gs://invalid-bucket/example.png', # forbidden - "external_id": "image-without-access.jpg" - }, - ] * 2) + task = dataset.create_data_rows( + [ + { + "row_data": "gs://invalid-bucket/example.png", # forbidden + "external_id": "image-without-access.jpg", + }, + ] + * 2 + ) task.wait_till_done() return _upload_invalid_data_rows_for_dataset @pytest.fixture -def prompt_response_generation_project_with_new_dataset(client: Client, - rand_gen, request): +def prompt_response_generation_project_with_new_dataset( + client: Client, rand_gen, request +): """fixture is parametrize and needs project_type in request""" media_type = request.param prompt_response_project = client.create_prompt_response_generation_project( name=f"{media_type.value}-{rand_gen(str)}", dataset_name=f"{media_type.value}-{rand_gen(str)}", data_row_count=1, - media_type=media_type) + media_type=media_type, + ) yield prompt_response_project @@ -357,15 +382,17 @@ def prompt_response_generation_project_with_new_dataset(client: Client, @pytest.fixture -def prompt_response_generation_project_with_dataset_id(client: Client, dataset, - rand_gen, request): +def prompt_response_generation_project_with_dataset_id( + client: Client, dataset, rand_gen, request +): """fixture is parametrized and needs project_type in request""" media_type = request.param prompt_response_project = client.create_prompt_response_generation_project( name=f"{media_type.value}-{rand_gen(str)}", dataset_id=dataset.uid, data_row_count=1, - media_type=media_type) + media_type=media_type, + ) 
yield prompt_response_project @@ -384,10 +411,10 @@ def response_creation_project(client: Client, rand_gen): @pytest.fixture def prompt_response_features(rand_gen): - prompt_text = PromptResponseClassification( class_type=PromptResponseClassification.Type.PROMPT, - name=f"{rand_gen(str)}-prompt text") + name=f"{rand_gen(str)}-prompt text", + ) response_radio = PromptResponseClassification( class_type=PromptResponseClassification.Type.RESPONSE_RADIO, @@ -395,27 +422,33 @@ def prompt_response_features(rand_gen): options=[ ResponseOption(value=f"{rand_gen(str)}-first radio option answer"), ResponseOption(value=f"{rand_gen(str)}-second radio option answer"), - ]) + ], + ) response_checklist = PromptResponseClassification( class_type=PromptResponseClassification.Type.RESPONSE_CHECKLIST, name=f"{rand_gen(str)}-response checklist classification", options=[ ResponseOption( - value=f"{rand_gen(str)}-first checklist option answer"), + value=f"{rand_gen(str)}-first checklist option answer" + ), ResponseOption( - value=f"{rand_gen(str)}-second checklist option answer"), - ]) + value=f"{rand_gen(str)}-second checklist option answer" + ), + ], + ) response_text_with_char = PromptResponseClassification( class_type=PromptResponseClassification.Type.RESPONSE_TEXT, name=f"{rand_gen(str)}-response text with character min and max", character_min=1, - character_max=10) + character_max=10, + ) response_text = PromptResponseClassification( class_type=PromptResponseClassification.Type.RESPONSE_TEXT, - name=f"{rand_gen(str)}-response text") + name=f"{rand_gen(str)}-response text", + ) nested_response_radio = PromptResponseClassification( class_type=PromptResponseClassification.Type.RESPONSE_RADIO, @@ -425,54 +458,65 @@ def prompt_response_features(rand_gen): f"{rand_gen(str)}-first_radio_answer", options=[ PromptResponseClassification( - class_type=PromptResponseClassification.Type. 
- RESPONSE_RADIO, + class_type=PromptResponseClassification.Type.RESPONSE_RADIO, name=f"{rand_gen(str)}-sub_radio_question", options=[ ResponseOption( - f"{rand_gen(str)}-first_sub_radio_answer") - ]) - ]) - ]) + f"{rand_gen(str)}-first_sub_radio_answer" + ) + ], + ) + ], + ) + ], + ) yield { "prompts": [prompt_text], "responses": [ - response_text, response_radio, response_checklist, - response_text_with_char, nested_response_radio - ] + response_text, + response_radio, + response_checklist, + response_text_with_char, + nested_response_radio, + ], } @pytest.fixture -def prompt_response_ontology(client: Client, rand_gen, prompt_response_features, - request): +def prompt_response_ontology( + client: Client, rand_gen, prompt_response_features, request +): """fixture is parametrize and needs project_type in request""" project_type = request.param if project_type == MediaType.LLMPromptCreation: ontology_builder = OntologyBuilder( - tools=[], classifications=prompt_response_features["prompts"]) + tools=[], classifications=prompt_response_features["prompts"] + ) elif project_type == MediaType.LLMPromptResponseCreation: ontology_builder = OntologyBuilder( tools=[], - classifications=prompt_response_features["prompts"] + - prompt_response_features["responses"]) + classifications=prompt_response_features["prompts"] + + prompt_response_features["responses"], + ) else: ontology_builder = OntologyBuilder( - tools=[], classifications=prompt_response_features["responses"]) + tools=[], classifications=prompt_response_features["responses"] + ) ontology_name = f"prompt-response-{rand_gen(str)}" if project_type in MediaType: - ontology = client.create_ontology(ontology_name, - ontology_builder.asdict(), - media_type=project_type) + ontology = client.create_ontology( + ontology_name, ontology_builder.asdict(), media_type=project_type + ) else: ontology = client.create_ontology( ontology_name, ontology_builder.asdict(), media_type=MediaType.Text, - ontology_kind=OntologyKind.ResponseCreation) + ontology_kind=OntologyKind.ResponseCreation, + ) yield ontology featureSchemaIds = [ @@ -503,7 +547,8 @@ def feature_schema(client, point): yield created_feature_schema client.delete_unused_feature_schema( - created_feature_schema.normalized['featureSchemaId']) + created_feature_schema.normalized["featureSchemaId"] + ) @pytest.fixture @@ -511,55 +556,75 @@ def chat_evaluation_ontology(client, rand_gen): ontology_name = f"test-chat-evaluation-ontology-{rand_gen(str)}" ontology_builder = OntologyBuilder( tools=[ - Tool(tool=Tool.Type.MESSAGE_SINGLE_SELECTION, - name="model output single selection"), - Tool(tool=Tool.Type.MESSAGE_MULTI_SELECTION, - name="model output multi selection"), - Tool(tool=Tool.Type.MESSAGE_RANKING, - name="model output multi ranking"), + Tool( + tool=Tool.Type.MESSAGE_SINGLE_SELECTION, + name="model output single selection", + ), + Tool( + tool=Tool.Type.MESSAGE_MULTI_SELECTION, + name="model output multi selection", + ), + Tool( + tool=Tool.Type.MESSAGE_RANKING, + name="model output multi ranking", + ), ], classifications=[ - Classification(class_type=Classification.Type.TEXT, - name="global model output text classification", - scope=Classification.Scope.GLOBAL), - Classification(class_type=Classification.Type.RADIO, - name="global model output radio classification", - scope=Classification.Scope.GLOBAL, - options=[ - Option(value="global first option answer"), - Option(value="global second option answer"), - ]), - Classification(class_type=Classification.Type.CHECKLIST, - name="global model 
output checklist classification", - scope=Classification.Scope.GLOBAL, - options=[ - Option(value="global first option answer"), - Option(value="global second option answer"), - ]), - Classification(class_type=Classification.Type.TEXT, - name="index model output text classification", - scope=Classification.Scope.INDEX), - Classification(class_type=Classification.Type.RADIO, - name="index model output radio classification", - scope=Classification.Scope.INDEX, - options=[ - Option(value="index first option answer"), - Option(value="index second option answer"), - ]), - Classification(class_type=Classification.Type.CHECKLIST, - name="index model output checklist classification", - scope=Classification.Scope.INDEX, - options=[ - Option(value="index first option answer"), - Option(value="index second option answer"), - ]), - ]) + Classification( + class_type=Classification.Type.TEXT, + name="global model output text classification", + scope=Classification.Scope.GLOBAL, + ), + Classification( + class_type=Classification.Type.RADIO, + name="global model output radio classification", + scope=Classification.Scope.GLOBAL, + options=[ + Option(value="global first option answer"), + Option(value="global second option answer"), + ], + ), + Classification( + class_type=Classification.Type.CHECKLIST, + name="global model output checklist classification", + scope=Classification.Scope.GLOBAL, + options=[ + Option(value="global first option answer"), + Option(value="global second option answer"), + ], + ), + Classification( + class_type=Classification.Type.TEXT, + name="index model output text classification", + scope=Classification.Scope.INDEX, + ), + Classification( + class_type=Classification.Type.RADIO, + name="index model output radio classification", + scope=Classification.Scope.INDEX, + options=[ + Option(value="index first option answer"), + Option(value="index second option answer"), + ], + ), + Classification( + class_type=Classification.Type.CHECKLIST, + name="index model output checklist classification", + scope=Classification.Scope.INDEX, + options=[ + Option(value="index first option answer"), + Option(value="index second option answer"), + ], + ), + ], + ) ontology = client.create_ontology( ontology_name, ontology_builder.asdict(), media_type=MediaType.Conversational, - ontology_kind=OntologyKind.ModelEvaluation) + ontology_kind=OntologyKind.ModelEvaluation, + ) yield ontology @@ -573,9 +638,9 @@ def chat_evaluation_ontology(client, rand_gen): def live_chat_evaluation_project_with_new_dataset(client, rand_gen): project_name = f"test-model-evaluation-project-{rand_gen(str)}" dataset_name = f"test-model-evaluation-dataset-{rand_gen(str)}" - project = client.create_model_evaluation_project(name=project_name, - dataset_name=dataset_name, - data_row_count=1) + project = client.create_model_evaluation_project( + name=project_name, dataset_name=dataset_name, data_row_count=1 + ) yield project @@ -596,9 +661,9 @@ def offline_chat_evaluation_project(client, rand_gen): def chat_evaluation_project_append_to_dataset(client, dataset, rand_gen): project_name = f"test-model-evaluation-project-{rand_gen(str)}" dataset_id = dataset.uid - project = client.create_model_evaluation_project(name=project_name, - dataset_id=dataset_id, - data_row_count=1) + project = client.create_model_evaluation_project( + name=project_name, dataset_id=dataset_id, data_row_count=1 + ) yield project @@ -613,106 +678,102 @@ def offline_conversational_data_row(initial_dataset): "actors": { "clxhs9wk000013b6w7imiz0h8": { "role": 
"human", - "metadata": { - "name": "User" - } + "metadata": {"name": "User"}, }, "clxhsc6xb00013b6w1awh579j": { "role": "model", "metadata": { "modelConfigId": "5a50d319-56bd-405d-87bb-4442daea0d0f" - } + }, }, "clxhsc6xb00023b6wlp0768zs": { "role": "model", "metadata": { "modelConfigId": "1cfc833a-2684-47df-95ac-bb7d9f9e3e1f" - } - } + }, + }, }, "messages": { "clxhs9wk000023b6wrufora3k": { "actorId": "clxhs9wk000013b6w7imiz0h8", - "content": [{ - "type": "text", - "content": "Hello world" - }], - "childMessageIds": ["clxhscb4z00033b6wukpvmuol"] + "content": [{"type": "text", "content": "Hello world"}], + "childMessageIds": ["clxhscb4z00033b6wukpvmuol"], }, "clxhscb4z00033b6wukpvmuol": { "actorId": "clxhsc6xb00013b6w1awh579j", - "content": [{ - "type": - "text", - "content": - "Hello to you too! 👋 \n\nIt's great to be your guide in the digital world. What can I help you with today? 😊 \n" - }], - "childMessageIds": ["clxhu2s0900013b6wbv0ndddd"] + "content": [ + { + "type": "text", + "content": "Hello to you too! 👋 \n\nIt's great to be your guide in the digital world. What can I help you with today? 😊 \n", + } + ], + "childMessageIds": ["clxhu2s0900013b6wbv0ndddd"], }, "clxhu2s0900013b6wbv0ndddd": { - "actorId": - "clxhs9wk000013b6w7imiz0h8", - "content": [{ - "type": "text", - "content": "Lets some some multi-turn happening" - }], + "actorId": "clxhs9wk000013b6w7imiz0h8", + "content": [ + { + "type": "text", + "content": "Lets some some multi-turn happening", + } + ], "childMessageIds": [ - "clxhu4qib00023b6wuep47b1l", "clxhu4qib00033b6wf18az01q" - ] + "clxhu4qib00023b6wuep47b1l", + "clxhu4qib00033b6wf18az01q", + ], }, "clxhu4qib00023b6wuep47b1l": { "actorId": "clxhsc6xb00013b6w1awh579j", - "content": [{ - "type": - "text", - "content": - "Okay, I'm ready for some multi-turn fun! To make it interesting, how about we try building a story together? \n\n**Here's the beginning:**\n\nThe old, dusty book lay forgotten on the shelf, its leather cover cracked and faded. But as the afternoon sun slanted through the window, a single ray caught a glint of gold on the book's spine. Suddenly...\n\n**Now you tell me what happens!** What does the glint of gold turn out to be? What happens next? 🤔 \n" - }], - "childMessageIds": ["clxhu596m00043b6wvkgahcwz"] + "content": [ + { + "type": "text", + "content": "Okay, I'm ready for some multi-turn fun! To make it interesting, how about we try building a story together? \n\n**Here's the beginning:**\n\nThe old, dusty book lay forgotten on the shelf, its leather cover cracked and faded. But as the afternoon sun slanted through the window, a single ray caught a glint of gold on the book's spine. Suddenly...\n\n**Now you tell me what happens!** What does the glint of gold turn out to be? What happens next? 🤔 \n", + } + ], + "childMessageIds": ["clxhu596m00043b6wvkgahcwz"], }, "clxhu4qib00033b6wf18az01q": { "actorId": "clxhsc6xb00023b6wlp0768zs", - "content": [{ - "type": - "text", - "content": - "Okay, I'm game! Let's have a multi-turn conversation! But to get us started, give me a topic or a question to get us going! 😄 \n\nWhat's on your mind? \n" - }], - "childMessageIds": ["clxhu596m00043b6wvkgahcwz"] + "content": [ + { + "type": "text", + "content": "Okay, I'm game! Let's have a multi-turn conversation! But to get us started, give me a topic or a question to get us going! 😄 \n\nWhat's on your mind? 
\n", + } + ], + "childMessageIds": ["clxhu596m00043b6wvkgahcwz"], }, "clxhu596m00043b6wvkgahcwz": { - "actorId": - "clxhs9wk000013b6w7imiz0h8", - "content": [{ - "type": "text", - "content": "Can you generate an image?" - }], + "actorId": "clxhs9wk000013b6w7imiz0h8", + "content": [ + {"type": "text", "content": "Can you generate an image?"} + ], "childMessageIds": [ - "clxhu5hqk00053b6wk21u6xi7", "clxhu5hqk00063b6whl4gofx8" - ] + "clxhu5hqk00053b6wk21u6xi7", + "clxhu5hqk00063b6whl4gofx8", + ], }, "clxhu5hqk00053b6wk21u6xi7": { "actorId": "clxhsc6xb00013b6w1awh579j", - "content": [{ - "type": - "text", - "content": - "I can't generate images directly yet. I'm a text-based AI, so my superpowers are all about words! 🪄 \n\nHowever, I'd love to help you describe an image that fits our story. Tell me: \n\n* **What should the image show?** (For example: The dusty book, the glint of gold, the bookshelf, the room, etc.) \n* **What kind of feeling or atmosphere should the image have?** (Mysterious? Magical? Ominous?) \n\nThe more details you give me, the better I can describe the image! ✨ \n" - }], - "childMessageIds": [] + "content": [ + { + "type": "text", + "content": "I can't generate images directly yet. I'm a text-based AI, so my superpowers are all about words! 🪄 \n\nHowever, I'd love to help you describe an image that fits our story. Tell me: \n\n* **What should the image show?** (For example: The dusty book, the glint of gold, the bookshelf, the room, etc.) \n* **What kind of feeling or atmosphere should the image have?** (Mysterious? Magical? Ominous?) \n\nThe more details you give me, the better I can describe the image! ✨ \n", + } + ], + "childMessageIds": [], }, "clxhu5hqk00063b6whl4gofx8": { "actorId": "clxhsc6xb00023b6wlp0768zs", - "content": [{ - "type": - "text", - "content": - "I can't *actually* generate images directly. 😔 I'm primarily a text-based AI. \n\nTo help me understand what you'd like to see, tell me: \n\n* **What should be in the image?** Be specific! (e.g., \"A cat wearing a tiny hat\", \"A futuristic cityscape at sunset\")\n* **What style do you imagine?** (e.g., realistic, cartoonish, abstract)\n\nOnce you give me those details, I can try to give you a vivid description that's almost as good as seeing it! 😊 \n" - }], - "childMessageIds": [] - } + "content": [ + { + "type": "text", + "content": "I can't *actually* generate images directly. 😔 I'm primarily a text-based AI. \n\nTo help me understand what you'd like to see, tell me: \n\n* **What should be in the image?** Be specific! (e.g., \"A cat wearing a tiny hat\", \"A futuristic cityscape at sunset\")\n* **What style do you imagine?** (e.g., realistic, cartoonish, abstract)\n\nOnce you give me those details, I can try to give you a vivid description that's almost as good as seeing it! 
😊 \n", + } + ], + "childMessageIds": [], + }, }, - "rootMessageIds": ["clxhs9wk000023b6wrufora3k"] + "rootMessageIds": ["clxhs9wk000023b6wrufora3k"], } convo_v2_asset = { @@ -734,10 +795,8 @@ def response_data_row(initial_dataset): @pytest.fixture() def conversation_data_row(initial_dataset, rand_gen): data = { - "row_data": - "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", - "global_key": - f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{rand_gen(str)}", + "row_data": "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", + "global_key": f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{rand_gen(str)}", } convo_asset = {"row_data": data} data_row = initial_dataset.create_data_row(convo_asset) @@ -760,16 +819,19 @@ def pytest_fixture_setup(fixturedef): pytest.report[fixturedef.argname] += exec_time -@pytest.fixture(scope='session', autouse=True) +@pytest.fixture(scope="session", autouse=True) def print_perf_summary(): yield if "FIXTURE_PROFILE" in os.environ: sorted_dict = dict( - sorted(pytest.report.items(), - key=lambda item: item[1], - reverse=True)) + sorted( + pytest.report.items(), key=lambda item: item[1], reverse=True + ) + ) num_of_entries = 10 if len(sorted_dict) >= 10 else len(sorted_dict) - slowest_fixtures = [(aaa, sorted_dict[aaa]) - for aaa in islice(sorted_dict, num_of_entries)] + slowest_fixtures = [ + (aaa, sorted_dict[aaa]) + for aaa in islice(sorted_dict, num_of_entries) + ] print("\nTop slowest fixtures:\n", slowest_fixtures, file=sys.stderr) diff --git a/libs/labelbox/tests/integration/schema/test_user_group.py b/libs/labelbox/tests/integration/schema/test_user_group.py index 27882e2d7..6aebd4e89 100644 --- a/libs/labelbox/tests/integration/schema/test_user_group.py +++ b/libs/labelbox/tests/integration/schema/test_user_group.py @@ -3,10 +3,15 @@ from uuid import uuid4 from labelbox import Client from labelbox.schema.user_group import UserGroup, UserGroupColor -from labelbox.exceptions import ResourceNotFoundError, ResourceCreationError, UnprocessableEntityError +from labelbox.exceptions import ( + ResourceNotFoundError, + ResourceCreationError, + UnprocessableEntityError, +) data = faker.Faker() + @pytest.fixture def user_group(client): group_name = data.name() @@ -141,7 +146,7 @@ def test_cannot_update_group_id(user_group): def test_get_user_groups_with_creation_deletion(client): user_group = None - try: + try: # Get all user groups user_groups = list(UserGroup(client).get_user_groups()) @@ -167,7 +172,9 @@ def test_get_user_groups_with_creation_deletion(client): user_groups_post_deletion = list(UserGroup(client).get_user_groups()) - assert len(user_groups_post_deletion) == len(user_groups_post_creation) - 1 + assert ( + len(user_groups_post_deletion) == len(user_groups_post_creation) - 1 + ) finally: if user_group: @@ -217,4 +224,5 @@ def test_throw_error_delete_user_group_no_id(user_group, client): if __name__ == "__main__": import subprocess - subprocess.call(["pytest", "-v", __file__]) \ No newline at end of file + + subprocess.call(["pytest", "-v", __file__]) diff --git a/libs/labelbox/tests/integration/test_batch.py b/libs/labelbox/tests/integration/test_batch.py index d5e3b7a0f..3f9e720a3 100644 --- a/libs/labelbox/tests/integration/test_batch.py +++ 
b/libs/labelbox/tests/integration/test_batch.py @@ -4,7 +4,12 @@ import pytest from labelbox import Dataset, Project -from labelbox.exceptions import ProcessingWaitTimeout, MalformedQueryException, ResourceConflict, LabelboxError +from labelbox.exceptions import ( + ProcessingWaitTimeout, + MalformedQueryException, + ResourceConflict, + LabelboxError, +) def get_data_row_ids(ds: Dataset): @@ -12,13 +17,12 @@ def get_data_row_ids(ds: Dataset): def test_create_batch(project: Project, big_dataset_data_row_ids: List[str]): - batch = project.create_batch("test-batch", - big_dataset_data_row_ids, - 3, - consensus_settings={ - 'number_of_labels': 3, - 'coverage_percentage': 0.1 - }) + batch = project.create_batch( + "test-batch", + big_dataset_data_row_ids, + 3, + consensus_settings={"number_of_labels": 3, "coverage_percentage": 0.1}, + ) assert batch.name == "test-batch" assert batch.size == len(big_dataset_data_row_ids) @@ -27,86 +31,101 @@ def test_create_batch(project: Project, big_dataset_data_row_ids: List[str]): def test_create_batch_with_invalid_data_rows_ids(project: Project): with pytest.raises(MalformedQueryException) as ex: - project.create_batch("test-batch", data_rows=['a', 'b', 'c']) - assert str( - ex) == "No valid data rows to be added from the list provided!" - - -def test_create_batch_with_the_same_name(project: Project, - small_dataset: Dataset): - batch1 = project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset)) + project.create_batch("test-batch", data_rows=["a", "b", "c"]) + assert ( + str(ex) == "No valid data rows to be added from the list provided!" + ) + + +def test_create_batch_with_the_same_name( + project: Project, small_dataset: Dataset +): + batch1 = project.create_batch( + "batch1", data_rows=get_data_row_ids(small_dataset) + ) assert batch1.name == "batch1" with pytest.raises(ResourceConflict): - project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset)) + project.create_batch( + "batch1", data_rows=get_data_row_ids(small_dataset) + ) -def test_create_batch_with_same_data_row_ids(project: Project, - small_dataset: Dataset): - batch1 = project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset)) +def test_create_batch_with_same_data_row_ids( + project: Project, small_dataset: Dataset +): + batch1 = project.create_batch( + "batch1", data_rows=get_data_row_ids(small_dataset) + ) assert batch1.name == "batch1" with pytest.raises(MalformedQueryException) as ex: - project.create_batch("batch2", - data_rows=get_data_row_ids(small_dataset)) + project.create_batch( + "batch2", data_rows=get_data_row_ids(small_dataset) + ) assert str(ex) == "No valid data rows to add to project" def test_create_batch_with_non_existent_global_keys(project: Project): with pytest.raises(MalformedQueryException) as ex: project.create_batch("batch1", global_keys=["key1"]) - assert str( - ex - ) == "Data rows with the following global keys do not exist: key1." + assert ( + str(ex) + == "Data rows with the following global keys do not exist: key1." 
+ ) -def test_create_batch_with_string_priority(project: Project, - small_dataset: Dataset): +def test_create_batch_with_string_priority( + project: Project, small_dataset: Dataset +): with pytest.raises(LabelboxError): - project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset), - priority="abcd") + project.create_batch( + "batch1", data_rows=get_data_row_ids(small_dataset), priority="abcd" + ) -def test_create_batch_with_null_priority(project: Project, - small_dataset: Dataset): +def test_create_batch_with_null_priority( + project: Project, small_dataset: Dataset +): with pytest.raises(LabelboxError): - project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset), - priority=None) + project.create_batch( + "batch1", data_rows=get_data_row_ids(small_dataset), priority=None + ) -def test_create_batch_async(project: Project, - big_dataset_data_row_ids: List[str]): - batch = project._create_batch_async("big-batch", - big_dataset_data_row_ids, - priority=3) +def test_create_batch_async( + project: Project, big_dataset_data_row_ids: List[str] +): + batch = project._create_batch_async( + "big-batch", big_dataset_data_row_ids, priority=3 + ) assert batch.name == "big-batch" assert batch.size == len(big_dataset_data_row_ids) assert len([dr for dr in batch.failed_data_row_ids]) == 0 -def test_create_batch_with_consensus_settings(project: Project, - small_dataset: Dataset): +def test_create_batch_with_consensus_settings( + project: Project, small_dataset: Dataset +): export_task = small_dataset.export() export_task.wait_till_done() stream = export_task.get_buffered_stream() data_rows = [dr.json["data_row"]["id"] for dr in stream] consensus_settings = {"coverage_percentage": 0.1, "number_of_labels": 3} - batch = project.create_batch("batch with consensus settings", - data_rows, - 3, - consensus_settings=consensus_settings) + batch = project.create_batch( + "batch with consensus settings", + data_rows, + 3, + consensus_settings=consensus_settings, + ) assert batch.name == "batch with consensus settings" assert batch.size == len(data_rows) assert batch.consensus_settings == consensus_settings -def test_create_batch_with_data_row_class(project: Project, - small_dataset: Dataset): +def test_create_batch_with_data_row_class( + project: Project, small_dataset: Dataset +): export_task = small_dataset.export() export_task.wait_till_done() stream = export_task.get_buffered_stream() @@ -121,11 +140,11 @@ def test_archive_batch(project: Project, small_dataset: Dataset): export_task.wait_till_done() stream = export_task.get_buffered_stream() data_rows = [dr.json["data_row"]["id"] for dr in stream] - + batch = project.create_batch("batch to archive", data_rows) batch.remove_queued_data_rows() overview = project.get_overview() - + assert overview.to_label == 0 @@ -145,8 +164,9 @@ def test_batch_project(project: Project, small_dataset: Dataset): export_task.wait_till_done() stream = export_task.get_buffered_stream() data_rows = [dr.json["data_row"]["id"] for dr in stream] - batch = project.create_batch("batch to test project relationship", - data_rows) + batch = project.create_batch( + "batch to test project relationship", data_rows + ) project_from_batch = batch.project() @@ -155,8 +175,10 @@ def test_batch_project(project: Project, small_dataset: Dataset): def test_batch_creation_for_data_rows_with_issues( - project: Project, small_dataset: Dataset, - dataset_with_invalid_data_rows: Dataset): + project: Project, + small_dataset: Dataset, + dataset_with_invalid_data_rows: Dataset, 
+): """ Create a batch containing both valid and invalid data rows """ @@ -167,8 +189,9 @@ def test_batch_creation_for_data_rows_with_issues( data_rows_to_add = valid_data_rows + invalid_data_rows assert len(data_rows_to_add) == 4 - batch = project.create_batch("batch to test failed data rows", - data_rows_to_add) + batch = project.create_batch( + "batch to test failed data rows", data_rows_to_add + ) failed_data_row_ids = [x for x in batch.failed_data_row_ids] assert len(failed_data_row_ids) == 2 @@ -178,8 +201,11 @@ def test_batch_creation_for_data_rows_with_issues( def test_batch_creation_with_processing_timeout( - project: Project, small_dataset: Dataset, unique_dataset: Dataset, - upload_invalid_data_rows_for_dataset): + project: Project, + small_dataset: Dataset, + unique_dataset: Dataset, + upload_invalid_data_rows_for_dataset, +): """ Create a batch with zero wait time, this means that the waiting logic will throw exception immediately """ @@ -202,15 +228,16 @@ def test_batch_creation_with_processing_timeout( @pytest.mark.export_v1("export_v1 test remove later") -def test_export_data_rows(project: Project, dataset: Dataset, image_url: str, - external_id: str): +def test_export_data_rows( + project: Project, dataset: Dataset, image_url: str, external_id: str +): n_data_rows = 2 - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": external_id - }, - ] * n_data_rows) + task = dataset.create_data_rows( + [ + {"row_data": image_url, "external_id": external_id}, + ] + * n_data_rows + ) task.wait_till_done() data_rows = [dr.uid for dr in list(dataset.export_data_rows())] @@ -227,10 +254,10 @@ def test_list_all_batches(project: Project, client, image_url: str): Test to verify that we can retrieve all available batches in the project. """ # Data to use - img_assets = [{ - "row_data": image_url, - "external_id": str(uuid4()) - } for asset in range(0, 2)] + img_assets = [ + {"row_data": image_url, "external_id": str(uuid4())} + for asset in range(0, 2) + ] data = [img_assets for _ in range(0, 2)] # Setup @@ -245,8 +272,9 @@ def test_list_all_batches(project: Project, client, image_url: str): for dataset in datasets: data_row_ids = get_data_row_ids(dataset) - new_batch = project.create_batch(name=str(uuid4()), - data_rows=data_row_ids) + new_batch = project.create_batch( + name=str(uuid4()), data_rows=data_row_ids + ) batches.append(new_batch) # Test @@ -269,7 +297,8 @@ def test_list_project_batches_with_no_batches(project: Project): @pytest.mark.skip( reason="Test cannot be used effectively with MAL/LabelImport. \ -Fix/Unskip after resolving deletion with MAL/LabelImport") +Fix/Unskip after resolving deletion with MAL/LabelImport" +) def test_delete_labels(project, small_dataset): export_task = small_dataset.export() export_task.wait_till_done() @@ -280,14 +309,16 @@ def test_delete_labels(project, small_dataset): @pytest.mark.skip( reason="Test cannot be used effectively with MAL/LabelImport. 
\ -Fix/Unskip after resolving deletion with MAL/LabelImport") +Fix/Unskip after resolving deletion with MAL/LabelImport" +) def test_delete_labels_with_templates(project: Project, small_dataset: Dataset): export_task = small_dataset.export() export_task.wait_till_done() stream = export_task.get_buffered_stream() data_rows = [dr.json["data_row"]["id"] for dr in stream] - batch = project.create_batch("batch to delete labels w templates", - data_rows) + batch = project.create_batch( + "batch to delete labels w templates", data_rows + ) export_task = project.export(filters={"batch_ids": [batch.uid]}) export_task.wait_till_done() diff --git a/libs/labelbox/tests/integration/test_batches.py b/libs/labelbox/tests/integration/test_batches.py index 5c24a65f0..cabae4053 100644 --- a/libs/labelbox/tests/integration/test_batches.py +++ b/libs/labelbox/tests/integration/test_batches.py @@ -6,9 +6,9 @@ def test_create_batches(project: Project, big_dataset_data_row_ids: List[str]): - task = project.create_batches("test-batch", - big_dataset_data_row_ids, - priority=3) + task = project.create_batches( + "test-batch", big_dataset_data_row_ids, priority=3 + ) task.wait_till_done() assert task.errors() is None @@ -26,9 +26,9 @@ def test_create_batches_from_dataset(project: Project, big_dataset: Dataset): data_rows = [dr.json["data_row"]["id"] for dr in stream] project._wait_until_data_rows_are_processed(data_rows, [], 300) - task = project.create_batches_from_dataset("test-batch", - big_dataset.uid, - priority=3) + task = project.create_batches_from_dataset( + "test-batch", big_dataset.uid, priority=3 + ) task.wait_till_done() assert task.errors() is None diff --git a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py index aafcddbcc..47e39e2cf 100644 --- a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py +++ b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py @@ -7,9 +7,12 @@ def test_create_chat_evaluation_ontology_project( - client, chat_evaluation_ontology, - live_chat_evaluation_project_with_new_dataset, - offline_conversational_data_row, rand_gen): + client, + chat_evaluation_ontology, + live_chat_evaluation_project_with_new_dataset, + offline_conversational_data_row, + rand_gen, +): ontology = chat_evaluation_ontology # here we are essentially testing the ontology creation which is a fixture @@ -20,7 +23,7 @@ def test_create_chat_evaluation_ontology_project( assert tool.schema_id assert tool.feature_schema_id - assert (len(ontology.classifications()) == 6) + assert len(ontology.classifications()) == 6 for classification in ontology.classifications(): assert classification.schema_id assert classification.feature_schema_id @@ -34,29 +37,32 @@ def test_create_chat_evaluation_ontology_project( assert project.ontology().name == ontology.name with pytest.raises( - ValueError, - match="Cannot create batches for auto data generation projects"): + ValueError, + match="Cannot create batches for auto data generation projects", + ): project.create_batch( rand_gen(str), [offline_conversational_data_row.uid], # sample of data row objects ) with pytest.raises( - ValueError, - match="Cannot create batches for auto data generation projects"): - with patch('labelbox.schema.project.MAX_SYNC_BATCH_ROW_COUNT', - new=0): # force to async - + ValueError, + match="Cannot create batches for auto data generation projects", + ): + with patch( + 
"labelbox.schema.project.MAX_SYNC_BATCH_ROW_COUNT", new=0 + ): # force to async project.create_batch( rand_gen(str), - [offline_conversational_data_row.uid + [ + offline_conversational_data_row.uid ], # sample of data row objects ) def test_create_chat_evaluation_ontology_project_existing_dataset( - client, chat_evaluation_ontology, - chat_evaluation_project_append_to_dataset): + client, chat_evaluation_ontology, chat_evaluation_project_append_to_dataset +): ontology = chat_evaluation_ontology project = chat_evaluation_project_append_to_dataset @@ -69,31 +75,35 @@ def test_create_chat_evaluation_ontology_project_existing_dataset( @pytest.fixture def tools_json(): - tools = [{ - 'tool': 'message-single-selection', - 'name': 'model output single selection', - 'required': False, - 'color': '#ff0000', - 'classifications': [], - 'schemaNodeId': None, - 'featureSchemaId': None - }, { - 'tool': 'message-multi-selection', - 'name': 'model output multi selection', - 'required': False, - 'color': '#00ff00', - 'classifications': [], - 'schemaNodeId': None, - 'featureSchemaId': None - }, { - 'tool': 'message-ranking', - 'name': 'model output multi ranking', - 'required': False, - 'color': '#0000ff', - 'classifications': [], - 'schemaNodeId': None, - 'featureSchemaId': None - }] + tools = [ + { + "tool": "message-single-selection", + "name": "model output single selection", + "required": False, + "color": "#ff0000", + "classifications": [], + "schemaNodeId": None, + "featureSchemaId": None, + }, + { + "tool": "message-multi-selection", + "name": "model output multi selection", + "required": False, + "color": "#00ff00", + "classifications": [], + "schemaNodeId": None, + "featureSchemaId": None, + }, + { + "tool": "message-ranking", + "name": "model output multi ranking", + "required": False, + "color": "#0000ff", + "classifications": [], + "schemaNodeId": None, + "featureSchemaId": None, + }, + ] return tools @@ -124,19 +134,21 @@ def ontology_from_feature_ids(client, features_from_json): client.delete_unused_ontology(ontology.uid) -def test_ontology_create_feature_schema(ontology_from_feature_ids, - features_from_json, tools_json): +def test_ontology_create_feature_schema( + ontology_from_feature_ids, features_from_json, tools_json +): created_ontology = ontology_from_feature_ids feature_schema_ids = {f.uid for f in features_from_json} - tools_normalized = created_ontology.normalized['tools'] + tools_normalized = created_ontology.normalized["tools"] tools = tools_json for tool in tools: generated_tool = next( - t for t in tools_normalized if t['name'] == tool['name']) - assert generated_tool['schemaNodeId'] is not None - assert generated_tool['featureSchemaId'] in feature_schema_ids - assert generated_tool['tool'] == tool['tool'] - assert generated_tool['name'] == tool['name'] - assert generated_tool['required'] == tool['required'] - assert generated_tool['color'] == tool['color'] + t for t in tools_normalized if t["name"] == tool["name"] + ) + assert generated_tool["schemaNodeId"] is not None + assert generated_tool["featureSchemaId"] in feature_schema_ids + assert generated_tool["tool"] == tool["tool"] + assert generated_tool["name"] == tool["name"] + assert generated_tool["required"] == tool["required"] + assert generated_tool["color"] == tool["color"] diff --git a/libs/labelbox/tests/integration/test_client_errors.py b/libs/labelbox/tests/integration/test_client_errors.py index 411b9e3b0..64b8fb626 100644 --- a/libs/labelbox/tests/integration/test_client_errors.py +++ 
b/libs/labelbox/tests/integration/test_client_errors.py @@ -40,7 +40,7 @@ def test_syntax_error(client): def test_semantic_error(client): with pytest.raises(labelbox.exceptions.InvalidQueryError) as excinfo: client.execute("query {bbb {id}}", check_naming=False) - assert excinfo.value.message.startswith("Cannot query field \"bbb\"") + assert excinfo.value.message.startswith('Cannot query field "bbb"') def test_timeout_error(client, project): @@ -59,8 +59,9 @@ def test_timeout_error(client, project): def test_query_complexity_error(client): with pytest.raises(labelbox.exceptions.ValidationFailedError) as excinfo: - client.execute("{projects {datasets {dataRows {labels {id}}}}}", - check_naming=False) + client.execute( + "{projects {datasets {dataRows {labels {id}}}}}", check_naming=False + ) assert excinfo.value.message == "Query complexity limit exceeded" @@ -70,8 +71,9 @@ def test_resource_not_found_error(client): def test_network_error(client): - client = labelbox.client.Client(api_key=client.api_key, - endpoint="not_a_valid_URL") + client = labelbox.client.Client( + api_key=client.api_key, endpoint="not_a_valid_URL" + ) with pytest.raises(labelbox.exceptions.NetworkError) as excinfo: client.create_project(name="Project name") @@ -103,7 +105,6 @@ def test_invalid_attribute_error( @pytest.mark.skip("timeouts cause failure before rate limit") def test_api_limit_error(client): - def get(arg): try: return client.get_user() diff --git a/libs/labelbox/tests/integration/test_data_row_delete_metadata.py b/libs/labelbox/tests/integration/test_data_row_delete_metadata.py index 8674beb33..2df860181 100644 --- a/libs/labelbox/tests/integration/test_data_row_delete_metadata.py +++ b/libs/labelbox/tests/integration/test_data_row_delete_metadata.py @@ -5,7 +5,12 @@ from labelbox import DataRow, Dataset, Client, DataRowMetadataOntology from labelbox.exceptions import MalformedQueryException -from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadata, DataRowMetadataKind, DeleteDataRowMetadata +from labelbox.schema.data_row_metadata import ( + DataRowMetadataField, + DataRowMetadata, + DataRowMetadataKind, + DeleteDataRowMetadata, +) from labelbox.schema.identifiable import GlobalKey, UniqueId INVALID_SCHEMA_ID = "1" * 25 @@ -16,13 +21,13 @@ TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt" TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh" CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb" -CUSTOM_TEXT_SCHEMA_NAME = 'custom_text' +CUSTOM_TEXT_SCHEMA_NAME = "custom_text" FAKE_NUMBER_FIELD = { "id": FAKE_SCHEMA_ID, "name": "number", - "kind": 'CustomMetadataNumber', - "reserved": False + "kind": "CustomMetadataNumber", + "reserved": False, } @@ -42,13 +47,16 @@ def mdo(client: Client): @pytest.fixture def big_dataset(dataset: Dataset, image_url): - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image", - "global_key": str(uuid.uuid4()) - }, - ] * 5) + task = dataset.create_data_rows( + [ + { + "row_data": image_url, + "external_id": "my-image", + "global_key": str(uuid.uuid4()), + }, + ] + * 5 + ) task.wait_till_done() yield dataset @@ -62,11 +70,13 @@ def make_metadata(dr_id: str = None, gk: str = None) -> DataRowMetadata: global_key=gk, data_row_id=dr_id, fields=[ - DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID, - value=TEST_SPLIT_ID), + DataRowMetadataField( + schema_id=SPLIT_SCHEMA_ID, value=TEST_SPLIT_ID + ), DataRowMetadataField(schema_id=CAPTURE_DT_SCHEMA_ID, value=time), DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value=msg), - ]) + ], + ) 
return metadata @@ -74,15 +84,14 @@ def make_named_metadata(dr_id) -> DataRowMetadata: msg = "A message" time = datetime.now(timezone.utc) - metadata = DataRowMetadata(data_row_id=dr_id, - fields=[ - DataRowMetadataField(name='split', - value=TEST_SPLIT_ID), - DataRowMetadataField(name='captureDateTime', - value=time), - DataRowMetadataField( - name=CUSTOM_TEXT_SCHEMA_NAME, value=msg), - ]) + metadata = DataRowMetadata( + data_row_id=dr_id, + fields=[ + DataRowMetadataField(name="split", value=TEST_SPLIT_ID), + DataRowMetadataField(name="captureDateTime", value=time), + DataRowMetadataField(name=CUSTOM_TEXT_SCHEMA_NAME, value=msg), + ], + ) return metadata @@ -94,9 +103,11 @@ def test_bulk_delete_datarow_metadata(data_row, mdo): assert len(mdo.bulk_export([data_row.uid])[0].fields) upload_ids = [m.schema_id for m in metadata.fields[:-2]] mdo.bulk_delete( - [DeleteDataRowMetadata(data_row_id=data_row.uid, fields=upload_ids)]) + [DeleteDataRowMetadata(data_row_id=data_row.uid, fields=upload_ids)] + ) remaining_ids = set( - [f.schema_id for f in mdo.bulk_export([data_row.uid])[0].fields]) + [f.schema_id for f in mdo.bulk_export([data_row.uid])[0].fields] + ) assert not len(remaining_ids.intersection(set(upload_ids))) @@ -116,43 +127,55 @@ def data_row_id_as_str(data_row): @pytest.mark.parametrize( - 'data_row_for_delete', - ['data_row_id_as_str', 'data_row_unique_id', 'data_row_global_key']) -def test_bulk_delete_datarow_metadata(data_row_for_delete, data_row, mdo, - request): + "data_row_for_delete", + ["data_row_id_as_str", "data_row_unique_id", "data_row_global_key"], +) +def test_bulk_delete_datarow_metadata( + data_row_for_delete, data_row, mdo, request +): """test bulk deletes for all fields""" metadata = make_metadata(data_row.uid) mdo.bulk_upsert([metadata]) assert len(mdo.bulk_export([data_row.uid])[0].fields) upload_ids = [m.schema_id for m in metadata.fields[:-2]] - mdo.bulk_delete([ - DeleteDataRowMetadata( - data_row_id=request.getfixturevalue(data_row_for_delete), - fields=upload_ids) - ]) + mdo.bulk_delete( + [ + DeleteDataRowMetadata( + data_row_id=request.getfixturevalue(data_row_for_delete), + fields=upload_ids, + ) + ] + ) remaining_ids = set( - [f.schema_id for f in mdo.bulk_export([data_row.uid])[0].fields]) + [f.schema_id for f in mdo.bulk_export([data_row.uid])[0].fields] + ) assert not len(remaining_ids.intersection(set(upload_ids))) @pytest.mark.parametrize( - 'data_row_for_delete', - ['data_row_id_as_str', 'data_row_unique_id', 'data_row_global_key']) -def test_bulk_partial_delete_datarow_metadata(data_row_for_delete, data_row, - mdo, request): + "data_row_for_delete", + ["data_row_id_as_str", "data_row_unique_id", "data_row_global_key"], +) +def test_bulk_partial_delete_datarow_metadata( + data_row_for_delete, data_row, mdo, request +): """Delete a single from metadata""" n_fields = len(mdo.bulk_export([data_row.uid])[0].fields) metadata = make_metadata(data_row.uid) mdo.bulk_upsert([metadata]) - assert len(mdo.bulk_export( - [data_row.uid])[0].fields) == (n_fields + len(metadata.fields)) + assert len(mdo.bulk_export([data_row.uid])[0].fields) == ( + n_fields + len(metadata.fields) + ) - mdo.bulk_delete([ - DeleteDataRowMetadata( - data_row_id=request.getfixturevalue(data_row_for_delete), - fields=[TEXT_SCHEMA_ID]) - ]) + mdo.bulk_delete( + [ + DeleteDataRowMetadata( + data_row_id=request.getfixturevalue(data_row_for_delete), + fields=[TEXT_SCHEMA_ID], + ) + ] + ) fields = [f for f in mdo.bulk_export([data_row.uid])[0].fields] assert len(fields) == 
(len(metadata.fields) - 1) @@ -166,7 +189,9 @@ def data_row_unique_ids(big_dataset): deletes.append( DeleteDataRowMetadata( data_row_id=UniqueId(data_row_id), - fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID])) + fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID], + ) + ) return deletes @@ -179,7 +204,9 @@ def data_row_ids_as_str(big_dataset): deletes.append( DeleteDataRowMetadata( data_row_id=data_row_id, - fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID])) + fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID], + ) + ) return deletes @@ -192,26 +219,35 @@ def data_row_global_keys(big_dataset): deletes.append( DeleteDataRowMetadata( data_row_id=GlobalKey(data_row_id), - fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID])) + fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID], + ) + ) return deletes @pytest.mark.parametrize( - 'data_rows_for_delete', - ['data_row_ids_as_str', 'data_row_unique_ids', 'data_row_global_keys']) -def test_large_bulk_delete_datarow_metadata(data_rows_for_delete, big_dataset, - mdo, request): + "data_rows_for_delete", + ["data_row_ids_as_str", "data_row_unique_ids", "data_row_global_keys"], +) +def test_large_bulk_delete_datarow_metadata( + data_rows_for_delete, big_dataset, mdo, request +): metadata = [] data_row_ids = [dr.uid for dr in big_dataset.data_rows()] for data_row_id in data_row_ids: metadata.append( - DataRowMetadata(data_row_id=data_row_id, - fields=[ - DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID, - value=TEST_SPLIT_ID), - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, - value="test-message") - ])) + DataRowMetadata( + data_row_id=data_row_id, + fields=[ + DataRowMetadataField( + schema_id=SPLIT_SCHEMA_ID, value=TEST_SPLIT_ID + ), + DataRowMetadataField( + schema_id=TEXT_SCHEMA_ID, value="test-message" + ), + ], + ) + ) errors = mdo.bulk_upsert(metadata) assert len(errors) == 0 @@ -221,7 +257,7 @@ def test_large_bulk_delete_datarow_metadata(data_rows_for_delete, big_dataset, assert len(errors) == len(data_row_ids) for error in errors: assert error.fields == [CAPTURE_DT_SCHEMA_ID] - assert error.error == 'Schema did not exist' + assert error.error == "Schema did not exist" for data_row_id in data_row_ids: fields = [f for f in mdo.bulk_export([data_row_id])[0].fields] @@ -230,10 +266,15 @@ def test_large_bulk_delete_datarow_metadata(data_rows_for_delete, big_dataset, @pytest.mark.parametrize( - 'data_row_for_delete', - ['data_row_id_as_str', 'data_row_unique_id', 'data_row_global_key']) -def test_bulk_delete_datarow_enum_metadata(data_row_for_delete, - data_row: DataRow, mdo: DataRowMetadataOntology, request): + "data_row_for_delete", + ["data_row_id_as_str", "data_row_unique_id", "data_row_global_key"], +) +def test_bulk_delete_datarow_enum_metadata( + data_row_for_delete, + data_row: DataRow, + mdo: DataRowMetadataOntology, + request, +): """test bulk deletes for non non fields""" metadata = make_metadata(data_row.uid) metadata.fields = [ @@ -243,28 +284,39 @@ def test_bulk_delete_datarow_enum_metadata(data_row_for_delete, exported = mdo.bulk_export([data_row.uid])[0].fields assert len(exported) == len( - set([x.schema_id for x in metadata.fields] + - [x.schema_id for x in exported])) - - mdo.bulk_delete([ - DeleteDataRowMetadata( - data_row_id=request.getfixturevalue(data_row_for_delete), - fields=[SPLIT_SCHEMA_ID]) - ]) + set( + [x.schema_id for x in metadata.fields] + + [x.schema_id for x in exported] + ) + ) + + mdo.bulk_delete( + [ + DeleteDataRowMetadata( + data_row_id=request.getfixturevalue(data_row_for_delete), + fields=[SPLIT_SCHEMA_ID], + ) + ] + ) exported = 
mdo.bulk_export([data_row.uid])[0].fields assert len(exported) == 0 @pytest.mark.parametrize( - 'data_row_for_delete', - ['data_row_id_as_str', 'data_row_unique_id', 'data_row_global_key']) -def test_delete_non_existent_schema_id(data_row_for_delete, data_row, mdo, - request): - res = mdo.bulk_delete([ - DeleteDataRowMetadata( - data_row_id=request.getfixturevalue(data_row_for_delete), - fields=[SPLIT_SCHEMA_ID]) - ]) + "data_row_for_delete", + ["data_row_id_as_str", "data_row_unique_id", "data_row_global_key"], +) +def test_delete_non_existent_schema_id( + data_row_for_delete, data_row, mdo, request +): + res = mdo.bulk_delete( + [ + DeleteDataRowMetadata( + data_row_id=request.getfixturevalue(data_row_for_delete), + fields=[SPLIT_SCHEMA_ID], + ) + ] + ) assert len(res) == 1 assert res[0].fields == [SPLIT_SCHEMA_ID] - assert res[0].error == 'Schema did not exist' + assert res[0].error == "Schema did not exist" diff --git a/libs/labelbox/tests/integration/test_data_rows.py b/libs/labelbox/tests/integration/test_data_rows.py index 454d55b87..7f69c2995 100644 --- a/libs/labelbox/tests/integration/test_data_rows.py +++ b/libs/labelbox/tests/integration/test_data_rows.py @@ -10,16 +10,26 @@ from labelbox.schema.media_type import MediaType from labelbox import DataRow, AssetAttachment -from labelbox.exceptions import MalformedQueryException, ResourceCreationError, InvalidQueryError +from labelbox.exceptions import ( + MalformedQueryException, + ResourceCreationError, + InvalidQueryError, +) from labelbox.schema.task import Task, DataUpsertTask -from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadataKind +from labelbox.schema.data_row_metadata import ( + DataRowMetadataField, + DataRowMetadataKind, +) SPLIT_SCHEMA_ID = "cko8sbczn0002h2dkdaxb5kal" TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt" TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh" CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb" EXPECTED_METADATA_SCHEMA_IDS = [ - SPLIT_SCHEMA_ID, TEST_SPLIT_ID, TEXT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID + SPLIT_SCHEMA_ID, + TEST_SPLIT_ID, + TEXT_SCHEMA_ID, + CAPTURE_DT_SCHEMA_ID, ].sort() CUSTOM_TEXT_SCHEMA_NAME = "custom_text" @@ -40,20 +50,19 @@ def mdo(client): @pytest.fixture def conversational_content(): return { - 'row_data': { - "messages": [{ - "messageId": "message-0", - "timestampUsec": 1530718491, - "content": "I love iphone! i just bought new iphone! 🥰 📲", - "user": { - "userId": "Bot 002", - "name": "Bot" - }, - "align": "left", - "canLabel": False - }], + "row_data": { + "messages": [ + { + "messageId": "message-0", + "timestampUsec": 1530718491, + "content": "I love iphone! i just bought new iphone! 
🥰 📲", + "user": {"userId": "Bot 002", "name": "Bot"}, + "align": "left", + "canLabel": False, + } + ], "version": 1, - "type": "application/vnd.labelbox.conversational" + "type": "application/vnd.labelbox.conversational", } } @@ -62,27 +71,24 @@ def conversational_content(): def tile_content(): return { "row_data": { - "tileLayerUrl": - "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png", - "bounds": [[19.405662413477728, -99.21052827588443], - [19.400498983095076, -99.20534818927473]], - "minZoom": - 12, - "maxZoom": - 20, - "epsg": - "EPSG4326", - "alternativeLayers": [{ - "tileLayerUrl": - "https://api.mapbox.com/styles/v1/mapbox/satellite-streets-v11/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw", - "name": - "Satellite" - }, { - "tileLayerUrl": - "https://api.mapbox.com/styles/v1/mapbox/navigation-guidance-night-v4/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw", - "name": - "Guidance" - }] + "tileLayerUrl": "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png", + "bounds": [ + [19.405662413477728, -99.21052827588443], + [19.400498983095076, -99.20534818927473], + ], + "minZoom": 12, + "maxZoom": 20, + "epsg": "EPSG4326", + "alternativeLayers": [ + { + "tileLayerUrl": "https://api.mapbox.com/styles/v1/mapbox/satellite-streets-v11/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw", + "name": "Satellite", + }, + { + "tileLayerUrl": "https://api.mapbox.com/styles/v1/mapbox/navigation-guidance-night-v4/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw", + "name": "Guidance", + }, + ], } } @@ -103,16 +109,11 @@ def make_metadata_fields_dict(): msg = "A message" time = datetime.utcnow() - fields = [{ - "schema_id": SPLIT_SCHEMA_ID, - "value": TEST_SPLIT_ID - }, { - "schema_id": CAPTURE_DT_SCHEMA_ID, - "value": time - }, { - "schema_id": TEXT_SCHEMA_ID, - "value": msg - }] + fields = [ + {"schema_id": SPLIT_SCHEMA_ID, "value": TEST_SPLIT_ID}, + {"schema_id": CAPTURE_DT_SCHEMA_ID, "value": time}, + {"schema_id": TEXT_SCHEMA_ID, "value": msg}, + ] return fields @@ -133,9 +134,9 @@ def test_create_invalid_aws_data_row(dataset, client): assert "s3" in exc.value.message with pytest.raises(InvalidQueryError) as exc: - dataset.create_data_rows([{ - "row_data": "s3://labelbox-public-data/invalid" - }]) + dataset.create_data_rows( + [{"row_data": "s3://labelbox-public-data/invalid"}] + ) assert "s3" in exc.value.message @@ -176,15 +177,12 @@ def test_data_row_bulk_creation(dataset, rand_gen, image_url): try: payload = [ - { - DataRow.row_data: image_url - }, - { - "row_data": image_url - }, + {DataRow.row_data: image_url}, + {"row_data": image_url}, ] - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', - new=300): # To make 2 chunks + with patch( + "labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES", new=300 + ): # To make 2 chunks # Test creation using URL task = dataset.create_data_rows(payload, file_upload_thread_count=2) task.wait_till_done() @@ -225,10 +223,12 @@ def local_image_file(image_url) -> NamedTemporaryFile: def test_data_row_bulk_creation_from_file(dataset, local_image_file, image_url): - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', - new=500): # Force chunking + with patch( + 
"labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES", new=500 + ): # Force chunking task = dataset.create_data_rows( - [local_image_file.name, local_image_file.name]) + [local_image_file.name, local_image_file.name] + ) task.wait_till_done() assert task.status == "COMPLETE" assert len(task.result) == 2 @@ -239,16 +239,17 @@ def test_data_row_bulk_creation_from_file(dataset, local_image_file, image_url): def test_data_row_bulk_creation_from_row_data_file_external_id( - dataset, local_image_file, image_url): - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', - new=500): # Force chunking - task = dataset.create_data_rows([{ - "row_data": local_image_file.name, - 'external_id': 'some_name' - }, { - "row_data": image_url, - 'external_id': 'some_name2' - }]) + dataset, local_image_file, image_url +): + with patch( + "labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES", new=500 + ): # Force chunking + task = dataset.create_data_rows( + [ + {"row_data": local_image_file.name, "external_id": "some_name"}, + {"row_data": image_url, "external_id": "some_name2"}, + ] + ) task.wait_till_done() assert task.status == "COMPLETE" assert len(task.result) == 2 @@ -259,15 +260,18 @@ def test_data_row_bulk_creation_from_row_data_file_external_id( assert image_url in row_data -def test_data_row_bulk_creation_from_row_data_file(dataset, rand_gen, - local_image_file, image_url): - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', - new=500): # Force chunking - task = dataset.create_data_rows([{ - "row_data": local_image_file.name - }, { - "row_data": local_image_file.name - }]) +def test_data_row_bulk_creation_from_row_data_file( + dataset, rand_gen, local_image_file, image_url +): + with patch( + "labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES", new=500 + ): # Force chunking + task = dataset.create_data_rows( + [ + {"row_data": local_image_file.name}, + {"row_data": local_image_file.name}, + ] + ) task.wait_till_done() assert task.status == "COMPLETE" assert len(task.result) == 2 @@ -285,9 +289,9 @@ def test_data_row_large_bulk_creation(dataset, image_url): with NamedTemporaryFile() as fp: fp.write("Test data".encode()) fp.flush() - task = dataset.create_data_rows([{ - DataRow.row_data: image_url - }] * n_urls + [fp.name] * n_local) + task = dataset.create_data_rows( + [{DataRow.row_data: image_url}] * n_urls + [fp.name] * n_local + ) task.wait_till_done() assert task.status == "COMPLETE" assert len(list(dataset.data_rows())) == n_local + n_urls @@ -302,8 +306,10 @@ def test_data_row_single_creation(dataset, rand_gen, image_url): assert data_row.dataset() == dataset assert data_row.created_by() == client.get_user() assert data_row.organization() == client.get_organization() - assert requests.get(image_url).content == \ - requests.get(data_row.row_data).content + assert ( + requests.get(image_url).content + == requests.get(data_row.row_data).content + ) assert data_row.media_attributes is not None assert data_row.global_key is None @@ -325,8 +331,10 @@ def test_create_data_row_with_dict(dataset, image_url): assert data_row.dataset() == dataset assert data_row.created_by() == client.get_user() assert data_row.organization() == client.get_organization() - assert requests.get(image_url).content == \ - requests.get(data_row.row_data).content + assert ( + requests.get(image_url).content + == requests.get(data_row.row_data).content + ) assert data_row.media_attributes is not None @@ -339,8 +347,10 @@ def test_create_data_row_with_dict_containing_field(dataset, image_url): assert data_row.dataset() 
== dataset assert data_row.created_by() == client.get_user() assert data_row.organization() == client.get_organization() - assert requests.get(image_url).content == \ - requests.get(data_row.row_data).content + assert ( + requests.get(image_url).content + == requests.get(data_row.row_data).content + ) assert data_row.media_attributes is not None @@ -353,8 +363,10 @@ def test_create_data_row_with_dict_unpacked(dataset, image_url): assert data_row.dataset() == dataset assert data_row.created_by() == client.get_user() assert data_row.organization() == client.get_organization() - assert requests.get(image_url).content == \ - requests.get(data_row.row_data).content + assert ( + requests.get(image_url).content + == requests.get(data_row.row_data).content + ) assert data_row.media_attributes is not None @@ -367,22 +379,26 @@ def test_create_data_row_with_metadata(mdo, dataset, image_url): client = dataset.client assert len(list(dataset.data_rows())) == 0 - data_row = dataset.create_data_row(row_data=image_url, - metadata_fields=make_metadata_fields()) + data_row = dataset.create_data_row( + row_data=image_url, metadata_fields=make_metadata_fields() + ) assert len(list(dataset.data_rows())) == 1 assert data_row.dataset() == dataset assert data_row.created_by() == client.get_user() assert data_row.organization() == client.get_organization() - assert requests.get(image_url).content == \ - requests.get(data_row.row_data).content + assert ( + requests.get(image_url).content + == requests.get(data_row.row_data).content + ) assert data_row.media_attributes is not None metadata_fields = data_row.metadata_fields metadata = data_row.metadata assert len(metadata_fields) == 3 assert len(metadata) == 3 - assert [m["schemaId"] for m in metadata_fields - ].sort() == EXPECTED_METADATA_SCHEMA_IDS + assert [ + m["schemaId"] for m in metadata_fields + ].sort() == EXPECTED_METADATA_SCHEMA_IDS for m in metadata: assert mdo._parse_upsert(m) @@ -392,21 +408,25 @@ def test_create_data_row_with_metadata_dict(mdo, dataset, image_url): assert len(list(dataset.data_rows())) == 0 data_row = dataset.create_data_row( - row_data=image_url, metadata_fields=make_metadata_fields_dict()) + row_data=image_url, metadata_fields=make_metadata_fields_dict() + ) assert len(list(dataset.data_rows())) == 1 assert data_row.dataset() == dataset assert data_row.created_by() == client.get_user() assert data_row.organization() == client.get_organization() - assert requests.get(image_url).content == \ - requests.get(data_row.row_data).content + assert ( + requests.get(image_url).content + == requests.get(data_row.row_data).content + ) assert data_row.media_attributes is not None metadata_fields = data_row.metadata_fields metadata = data_row.metadata assert len(metadata_fields) == 3 assert len(metadata) == 3 - assert [m["schemaId"] for m in metadata_fields - ].sort() == EXPECTED_METADATA_SCHEMA_IDS + assert [ + m["schemaId"] for m in metadata_fields + ].sort() == EXPECTED_METADATA_SCHEMA_IDS for m in metadata: assert mdo._parse_upsert(m) @@ -415,7 +435,8 @@ def test_create_data_row_with_invalid_metadata(dataset, image_url): fields = make_metadata_fields() # make the payload invalid by providing the same schema id more than once fields.append( - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value='some msg')) + DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value="some msg") + ) with pytest.raises(ResourceCreationError): dataset.create_data_row(row_data=image_url, metadata_fields=fields) @@ -425,28 +446,30 @@ def 
test_create_data_rows_with_metadata(mdo, dataset, image_url): client = dataset.client assert len(list(dataset.data_rows())) == 0 - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url, - DataRow.external_id: "row1", - DataRow.metadata_fields: make_metadata_fields() - }, - { - DataRow.row_data: image_url, - DataRow.external_id: "row2", - "metadata_fields": make_metadata_fields() - }, - { - DataRow.row_data: image_url, - DataRow.external_id: "row3", - DataRow.metadata_fields: make_metadata_fields_dict() - }, - { - DataRow.row_data: image_url, - DataRow.external_id: "row4", - "metadata_fields": make_metadata_fields_dict() - }, - ]) + task = dataset.create_data_rows( + [ + { + DataRow.row_data: image_url, + DataRow.external_id: "row1", + DataRow.metadata_fields: make_metadata_fields(), + }, + { + DataRow.row_data: image_url, + DataRow.external_id: "row2", + "metadata_fields": make_metadata_fields(), + }, + { + DataRow.row_data: image_url, + DataRow.external_id: "row3", + DataRow.metadata_fields: make_metadata_fields_dict(), + }, + { + DataRow.row_data: image_url, + DataRow.external_id: "row4", + "metadata_fields": make_metadata_fields_dict(), + }, + ] + ) task.wait_till_done() assert len(list(dataset.data_rows())) == 4 @@ -455,63 +478,60 @@ def test_create_data_rows_with_metadata(mdo, dataset, image_url): assert row.dataset() == dataset assert row.created_by() == client.get_user() assert row.organization() == client.get_organization() - assert requests.get(image_url).content == \ - requests.get(row.row_data).content + assert ( + requests.get(image_url).content + == requests.get(row.row_data).content + ) assert row.media_attributes is not None metadata_fields = row.metadata_fields metadata = row.metadata assert len(metadata_fields) == 3 assert len(metadata) == 3 - assert [m["schemaId"] for m in metadata_fields - ].sort() == EXPECTED_METADATA_SCHEMA_IDS + assert [ + m["schemaId"] for m in metadata_fields + ].sort() == EXPECTED_METADATA_SCHEMA_IDS for m in metadata: assert mdo._parse_upsert(m) -@pytest.mark.parametrize("test_function,metadata_obj_type", - [("create_data_rows", "class"), - ("create_data_rows", "dict"), - ("create_data_rows_sync", "class"), - ("create_data_rows_sync", "dict"), - ("create_data_row", "class"), - ("create_data_row", "dict")]) +@pytest.mark.parametrize( + "test_function,metadata_obj_type", + [ + ("create_data_rows", "class"), + ("create_data_rows", "dict"), + ("create_data_rows_sync", "class"), + ("create_data_rows_sync", "dict"), + ("create_data_row", "class"), + ("create_data_row", "dict"), + ], +) def test_create_data_rows_with_named_metadata_field_class( - test_function, metadata_obj_type, mdo, dataset, image_url): - + test_function, metadata_obj_type, mdo, dataset, image_url +): row_with_metadata_field = { - DataRow.row_data: - image_url, - DataRow.external_id: - "row1", + DataRow.row_data: image_url, + DataRow.external_id: "row1", DataRow.metadata_fields: [ - DataRowMetadataField(name='split', value='test'), - DataRowMetadataField(name=CUSTOM_TEXT_SCHEMA_NAME, value='hello') - ] + DataRowMetadataField(name="split", value="test"), + DataRowMetadataField(name=CUSTOM_TEXT_SCHEMA_NAME, value="hello"), + ], } row_with_metadata_dict = { - DataRow.row_data: - image_url, - DataRow.external_id: - "row2", + DataRow.row_data: image_url, + DataRow.external_id: "row2", "metadata_fields": [ - { - 'name': 'split', - 'value': 'test' - }, - { - 'name': CUSTOM_TEXT_SCHEMA_NAME, - 'value': 'hello' - }, - ] + {"name": "split", "value": "test"}, + {"name": 
CUSTOM_TEXT_SCHEMA_NAME, "value": "hello"}, + ], } assert len(list(dataset.data_rows())) == 0 METADATA_FIELDS = { "class": row_with_metadata_field, - "dict": row_with_metadata_dict + "dict": row_with_metadata_dict, } def create_data_row(data_rows): @@ -520,7 +540,7 @@ def create_data_row(data_rows): CREATION_FUNCTION = { "create_data_rows": dataset.create_data_rows, "create_data_rows_sync": dataset.create_data_rows_sync, - "create_data_row": create_data_row + "create_data_row": create_data_row, } data_rows = [METADATA_FIELDS[metadata_obj_type]] function_to_test = CREATION_FUNCTION[test_function] @@ -536,30 +556,33 @@ def create_data_row(data_rows): metadata = created_rows[0].metadata assert metadata[0].schema_id == SPLIT_SCHEMA_ID - assert metadata[0].name == 'test' - assert metadata[0].value == mdo.reserved_by_name['split']['test'].uid + assert metadata[0].name == "test" + assert metadata[0].value == mdo.reserved_by_name["split"]["test"].uid assert metadata[1].name == CUSTOM_TEXT_SCHEMA_NAME - assert metadata[1].value == 'hello' - assert metadata[1].schema_id == mdo.custom_by_name[ - CUSTOM_TEXT_SCHEMA_NAME].uid + assert metadata[1].value == "hello" + assert ( + metadata[1].schema_id == mdo.custom_by_name[CUSTOM_TEXT_SCHEMA_NAME].uid + ) def test_create_data_rows_with_invalid_metadata(dataset, image_url): fields = make_metadata_fields() # make the payload invalid by providing the same schema id more than once fields.append( - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value="some msg")) + DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value="some msg") + ) - task = dataset.create_data_rows([{ - DataRow.row_data: image_url, - DataRow.metadata_fields: fields - }]) + task = dataset.create_data_rows( + [{DataRow.row_data: image_url, DataRow.metadata_fields: fields}] + ) task.wait_till_done(timeout_seconds=60) assert task.status == "COMPLETE" assert len(task.failed_data_rows) == 1 - assert f"A schemaId can only be specified once per DataRow : [{TEXT_SCHEMA_ID}]" in task.failed_data_rows[ - 0]["message"] + assert ( + f"A schemaId can only be specified once per DataRow : [{TEXT_SCHEMA_ID}]" + in task.failed_data_rows[0]["message"] + ) def test_create_data_rows_with_metadata_missing_value(dataset, image_url): @@ -567,13 +590,15 @@ def test_create_data_rows_with_metadata_missing_value(dataset, image_url): fields.append({"schemaId": "some schema id"}) with pytest.raises(ValueError) as exc: - dataset.create_data_rows([ - { - DataRow.row_data: image_url, - DataRow.external_id: "row1", - DataRow.metadata_fields: fields - }, - ]) + dataset.create_data_rows( + [ + { + DataRow.row_data: image_url, + DataRow.external_id: "row1", + DataRow.metadata_fields: fields, + }, + ] + ) def test_create_data_rows_with_metadata_missing_schema_id(dataset, image_url): @@ -581,13 +606,15 @@ def test_create_data_rows_with_metadata_missing_schema_id(dataset, image_url): fields.append({"value": "some value"}) with pytest.raises(ValueError) as exc: - dataset.create_data_rows([ - { - DataRow.row_data: image_url, - DataRow.external_id: "row1", - DataRow.metadata_fields: fields - }, - ]) + dataset.create_data_rows( + [ + { + DataRow.row_data: image_url, + DataRow.external_id: "row1", + DataRow.metadata_fields: fields, + }, + ] + ) def test_create_data_rows_with_metadata_wrong_type(dataset, image_url): @@ -595,20 +622,24 @@ def test_create_data_rows_with_metadata_wrong_type(dataset, image_url): fields.append("Neither DataRowMetadataField or dict") with pytest.raises(ValueError) as exc: - task = dataset.create_data_rows([ - { 
- DataRow.row_data: image_url, - DataRow.external_id: "row1", - DataRow.metadata_fields: fields - }, - ]) + task = dataset.create_data_rows( + [ + { + DataRow.row_data: image_url, + DataRow.external_id: "row1", + DataRow.metadata_fields: fields, + }, + ] + ) def test_data_row_update_missing_or_empty_required_fields( - dataset, rand_gen, image_url): + dataset, rand_gen, image_url +): external_id = rand_gen(str) - data_row = dataset.create_data_row(row_data=image_url, - external_id=external_id) + data_row = dataset.create_data_row( + row_data=image_url, external_id=external_id + ) with pytest.raises(ValueError): data_row.update(row_data="") with pytest.raises(ValueError): @@ -621,11 +652,13 @@ def test_data_row_update_missing_or_empty_required_fields( data_row.update() -def test_data_row_update(client, dataset, rand_gen, image_url, - wait_for_data_row_processing): +def test_data_row_update( + client, dataset, rand_gen, image_url, wait_for_data_row_processing +): external_id = rand_gen(str) - data_row = dataset.create_data_row(row_data=image_url, - external_id=external_id) + data_row = dataset.create_data_row( + row_data=image_url, external_id=external_id + ) assert data_row.external_id == external_id external_id_2 = rand_gen(str) @@ -643,25 +676,23 @@ def test_data_row_update(client, dataset, rand_gen, image_url, # tileLayer becomes a media attribute pdf_url = "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" tileLayerUrl = "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483-lb-textlayer.json" - data_row.update(row_data={'pdfUrl': pdf_url, "tileLayerUrl": tileLayerUrl}) - custom_check = lambda data_row: data_row.row_data and 'pdfUrl' not in data_row.row_data - data_row = wait_for_data_row_processing(client, - data_row, - custom_check=custom_check) + data_row.update(row_data={"pdfUrl": pdf_url, "tileLayerUrl": tileLayerUrl}) + custom_check = ( + lambda data_row: data_row.row_data and "pdfUrl" not in data_row.row_data + ) + data_row = wait_for_data_row_processing( + client, data_row, custom_check=custom_check + ) assert data_row.row_data == pdf_url def test_data_row_filtering_sorting(dataset, image_url): - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url, - DataRow.external_id: "row1" - }, - { - DataRow.row_data: image_url, - DataRow.external_id: "row2" - }, - ]) + task = dataset.create_data_rows( + [ + {DataRow.row_data: image_url, DataRow.external_id: "row1"}, + {DataRow.row_data: image_url, DataRow.external_id: "row2"}, + ] + ) task.wait_till_done() # Test filtering @@ -681,10 +712,12 @@ def test_data_row_filtering_sorting(dataset, image_url): @pytest.fixture def create_datarows_for_data_row_deletion(dataset, image_url): - task = dataset.create_data_rows([{ - DataRow.row_data: image_url, - DataRow.external_id: str(i) - } for i in range(10)]) + task = dataset.create_data_rows( + [ + {DataRow.row_data: image_url, DataRow.external_id: str(i)} + for i in range(10) + ] + ) task.wait_till_done() data_rows = list(dataset.data_rows()) @@ -716,34 +749,39 @@ def test_data_row_deletion(dataset, create_datarows_for_data_row_deletion): def test_data_row_iteration(dataset, image_url) -> None: - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url - }, - { - "row_data": image_url - }, - ]) + task = dataset.create_data_rows( + [ + {DataRow.row_data: image_url}, + {"row_data": image_url}, + ] + ) task.wait_till_done() assert next(dataset.data_rows()) def 
test_data_row_attachments(dataset, image_url): - attachments = [("IMAGE", image_url, "attachment image"), - ("RAW_TEXT", "test-text", None), - ("IMAGE_OVERLAY", image_url, "Overlay"), - ("HTML", image_url, None)] - task = dataset.create_data_rows([{ - "row_data": - image_url, - "external_id": - "test-id", - "attachments": [{ - "type": attachment_type, - "value": attachment_value, - "name": attachment_name - }] - } for attachment_type, attachment_value, attachment_name in attachments]) + attachments = [ + ("IMAGE", image_url, "attachment image"), + ("RAW_TEXT", "test-text", None), + ("IMAGE_OVERLAY", image_url, "Overlay"), + ("HTML", image_url, None), + ] + task = dataset.create_data_rows( + [ + { + "row_data": image_url, + "external_id": "test-id", + "attachments": [ + { + "type": attachment_type, + "value": attachment_value, + "name": attachment_name, + } + ], + } + for attachment_type, attachment_value, attachment_name in attachments + ] + ) task.wait_till_done() assert task.status == "COMPLETE" @@ -754,33 +792,42 @@ def test_data_row_attachments(dataset, image_url): assert data_row.external_id == "test-id" with pytest.raises(ValueError) as exc: - task = dataset.create_data_rows([{ - "row_data": image_url, - "external_id": "test-id", - "attachments": [{ - "type": "INVALID", - "value": "123" - }] - }]) + task = dataset.create_data_rows( + [ + { + "row_data": image_url, + "external_id": "test-id", + "attachments": [{"type": "INVALID", "value": "123"}], + } + ] + ) def test_create_data_rows_sync_attachments(dataset, image_url): - attachments = [("IMAGE", image_url, "image URL"), - ("RAW_TEXT", "test-text", None), - ("IMAGE_OVERLAY", image_url, "Overlay"), - ("HTML", image_url, None)] + attachments = [ + ("IMAGE", image_url, "image URL"), + ("RAW_TEXT", "test-text", None), + ("IMAGE_OVERLAY", image_url, "Overlay"), + ("HTML", image_url, None), + ] attachments_per_data_row = 3 - dataset.create_data_rows_sync([{ - "row_data": - image_url, - "external_id": - "test-id", - "attachments": [{ - "type": attachment_type, - "value": attachment_value, - "name": attachment_name - } for _ in range(attachments_per_data_row)] - } for attachment_type, attachment_value, attachment_name in attachments]) + dataset.create_data_rows_sync( + [ + { + "row_data": image_url, + "external_id": "test-id", + "attachments": [ + { + "type": attachment_type, + "value": attachment_value, + "name": attachment_name, + } + for _ in range(attachments_per_data_row) + ], + } + for attachment_type, attachment_value, attachment_name in attachments + ] + ) data_rows = list(dataset.data_rows()) assert len(data_rows) == len(attachments) for data_row in data_rows: @@ -793,15 +840,16 @@ def test_create_data_rows_sync_mixed_upload(dataset, image_url): with NamedTemporaryFile() as fp: fp.write("Test data".encode()) fp.flush() - dataset.create_data_rows_sync([{ - DataRow.row_data: image_url - }] * n_urls + [fp.name] * n_local) + dataset.create_data_rows_sync( + [{DataRow.row_data: image_url}] * n_urls + [fp.name] * n_local + ) assert len(list(dataset.data_rows())) == n_local + n_urls def test_create_data_row_attachment(data_row): - att = data_row.create_attachment("IMAGE", "https://example.com/image.jpg", - "name") + att = data_row.create_attachment( + "IMAGE", "https://example.com/image.jpg", "name" + ) assert att.attachment_type == "IMAGE" assert att.attachment_value == "https://example.com/image.jpg" assert att.attachment_name == "name" @@ -823,21 +871,30 @@ def test_delete_data_row_attachment(data_row, image_url): attachments = [] # 
Anonymous attachment - to_attach = [("IMAGE", image_url), ("RAW_TEXT", "test-text"), - ("IMAGE_OVERLAY", image_url), ("HTML", image_url)] + to_attach = [ + ("IMAGE", image_url), + ("RAW_TEXT", "test-text"), + ("IMAGE_OVERLAY", image_url), + ("HTML", image_url), + ] for attachment_type, attachment_value in to_attach: attachments.append( - data_row.create_attachment(attachment_type, attachment_value)) + data_row.create_attachment(attachment_type, attachment_value) + ) # Attachment with a name - to_attach = [("IMAGE", image_url, "Att. Image"), - ("RAW_TEXT", "test-text", "Att. Text"), - ("IMAGE_OVERLAY", image_url, "Image Overlay"), - ("HTML", image_url, "Att. HTML")] + to_attach = [ + ("IMAGE", image_url, "Att. Image"), + ("RAW_TEXT", "test-text", "Att. Text"), + ("IMAGE_OVERLAY", image_url, "Image Overlay"), + ("HTML", image_url, "Att. HTML"), + ] for attachment_type, attachment_value, attachment_name in to_attach: attachments.append( - data_row.create_attachment(attachment_type, attachment_value, - attachment_name)) + data_row.create_attachment( + attachment_type, attachment_value, attachment_name + ) + ) for attachment in attachments: attachment.delete() @@ -847,7 +904,8 @@ def test_delete_data_row_attachment(data_row, image_url): def test_update_data_row_attachment(data_row, image_url): attachment: AssetAttachment = data_row.create_attachment( - "RAW_TEXT", "value", "name") + "RAW_TEXT", "value", "name" + ) assert attachment is not None attachment.update(name="updated name", type="IMAGE", value=image_url) assert attachment.attachment_name == "updated name" @@ -857,7 +915,8 @@ def test_update_data_row_attachment(data_row, image_url): def test_update_data_row_attachment_invalid_type(data_row): attachment: AssetAttachment = data_row.create_attachment( - "RAW_TEXT", "value", "name") + "RAW_TEXT", "value", "name" + ) assert attachment is not None with pytest.raises(ValueError): attachment.update(name="updated name", type="INVALID", value="value") @@ -865,7 +924,8 @@ def test_update_data_row_attachment_invalid_type(data_row): def test_update_data_row_attachment_invalid_value(data_row): attachment: AssetAttachment = data_row.create_attachment( - "RAW_TEXT", "value", "name") + "RAW_TEXT", "value", "name" + ) assert attachment is not None with pytest.raises(ValueError): attachment.update(name="updated name", type="IMAGE", value="") @@ -873,7 +933,8 @@ def test_update_data_row_attachment_invalid_value(data_row): def test_does_not_update_not_provided_attachment_fields(data_row): attachment: AssetAttachment = data_row.create_attachment( - "RAW_TEXT", "value", "name") + "RAW_TEXT", "value", "name" + ) assert attachment is not None attachment.update(value=None, name="name") assert attachment.attachment_value == "value" @@ -884,27 +945,33 @@ def test_does_not_update_not_provided_attachment_fields(data_row): def test_create_data_rows_result(client, dataset, image_url): - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url, - DataRow.external_id: "row1", - }, - { - DataRow.row_data: image_url, - DataRow.external_id: "row1", - }, - ]) + task = dataset.create_data_rows( + [ + { + DataRow.row_data: image_url, + DataRow.external_id: "row1", + }, + { + DataRow.row_data: image_url, + DataRow.external_id: "row1", + }, + ] + ) task.wait_till_done() assert task.errors is None for result in task.result: - client.get_data_row(result['id']) + client.get_data_row(result["id"]) def test_create_data_rows_local_file(dataset, sample_image): - task = dataset.create_data_rows([{ - DataRow.row_data: 
sample_image, - DataRow.metadata_fields: make_metadata_fields() - }]) + task = dataset.create_data_rows( + [ + { + DataRow.row_data: sample_image, + DataRow.metadata_fields: make_metadata_fields(), + } + ] + ) task.wait_till_done() assert task.status == "COMPLETE" data_row = list(dataset.data_rows())[0] @@ -914,10 +981,9 @@ def test_create_data_rows_local_file(dataset, sample_image): def test_data_row_with_global_key(dataset, sample_image): global_key = str(uuid.uuid4()) - row = dataset.create_data_row({ - DataRow.row_data: sample_image, - DataRow.global_key: global_key - }) + row = dataset.create_data_row( + {DataRow.row_data: sample_image, DataRow.global_key: global_key} + ) assert row.global_key == global_key @@ -927,36 +993,32 @@ def test_data_row_bulk_creation_with_unique_global_keys(dataset, sample_image): global_key_2 = str(uuid.uuid4()) global_key_3 = str(uuid.uuid4()) - task = dataset.create_data_rows([ - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }, - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_2 - }, - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_3 - }, - ]) + task = dataset.create_data_rows( + [ + {DataRow.row_data: sample_image, DataRow.global_key: global_key_1}, + {DataRow.row_data: sample_image, DataRow.global_key: global_key_2}, + {DataRow.row_data: sample_image, DataRow.global_key: global_key_3}, + ] + ) task.wait_till_done() - assert {row.global_key for row in dataset.data_rows() - } == {global_key_1, global_key_2, global_key_3} + assert {row.global_key for row in dataset.data_rows()} == { + global_key_1, + global_key_2, + global_key_3, + } -def test_data_row_bulk_creation_with_same_global_keys(dataset, sample_image, - snapshot): +def test_data_row_bulk_creation_with_same_global_keys( + dataset, sample_image, snapshot +): global_key_1 = str(uuid.uuid4()) - task = dataset.create_data_rows([{ - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }, { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }]) + task = dataset.create_data_rows( + [ + {DataRow.row_data: sample_image, DataRow.global_key: global_key_1}, + {DataRow.row_data: sample_image, DataRow.global_key: global_key_1}, + ] + ) task.wait_till_done() @@ -965,12 +1027,16 @@ def test_data_row_bulk_creation_with_same_global_keys(dataset, sample_image, assert len(task.failed_data_rows) == 1 assert type(task.created_data_rows) is list assert len(task.created_data_rows) == 1 - assert task.failed_data_rows[0][ - 'message'] == f"Duplicate global key: '{global_key_1}'" - assert task.failed_data_rows[0]['failedDataRows'][0][ - 'externalId'] == sample_image - assert task.created_data_rows[0]['external_id'] == sample_image - assert task.created_data_rows[0]['global_key'] == global_key_1 + assert ( + task.failed_data_rows[0]["message"] + == f"Duplicate global key: '{global_key_1}'" + ) + assert ( + task.failed_data_rows[0]["failedDataRows"][0]["externalId"] + == sample_image + ) + assert task.created_data_rows[0]["external_id"] == sample_image + assert task.created_data_rows[0]["global_key"] == global_key_1 assert len(task.errors) == 1 assert task.has_errors() is True @@ -980,11 +1046,12 @@ def test_data_row_bulk_creation_with_same_global_keys(dataset, sample_image, def test_data_row_delete_and_create_with_same_global_key( - client, dataset, sample_image): + client, dataset, sample_image +): global_key_1 = str(uuid.uuid4()) data_row_payload = { DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 + 
DataRow.global_key: global_key_1, } # should successfully insert new datarow @@ -992,9 +1059,9 @@ def test_data_row_delete_and_create_with_same_global_key( task.wait_till_done() assert task.status == "COMPLETE" - assert task.result[0]['global_key'] == global_key_1 + assert task.result[0]["global_key"] == global_key_1 - new_data_row_id = task.result[0]['id'] + new_data_row_id = task.result[0]["id"] # same payload should fail due to duplicated global key task = dataset.create_data_rows([data_row_payload]) @@ -1002,8 +1069,10 @@ def test_data_row_delete_and_create_with_same_global_key( assert task.status == "COMPLETE" assert len(task.failed_data_rows) == 1 - assert task.failed_data_rows[0][ - 'message'] == f"Duplicate global key: '{global_key_1}'" + assert ( + task.failed_data_rows[0]["message"] + == f"Duplicate global key: '{global_key_1}'" + ) # delete datarow client.get_data_row(new_data_row_id).delete() @@ -1013,46 +1082,49 @@ def test_data_row_delete_and_create_with_same_global_key( task.wait_till_done() assert task.status == "COMPLETE" - assert task.result[0]['global_key'] == global_key_1 + assert task.result[0]["global_key"] == global_key_1 def test_data_row_bulk_creation_sync_with_unique_global_keys( - dataset, sample_image): + dataset, sample_image +): global_key_1 = str(uuid.uuid4()) global_key_2 = str(uuid.uuid4()) global_key_3 = str(uuid.uuid4()) - dataset.create_data_rows_sync([ - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }, - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_2 - }, - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_3 - }, - ]) + dataset.create_data_rows_sync( + [ + {DataRow.row_data: sample_image, DataRow.global_key: global_key_1}, + {DataRow.row_data: sample_image, DataRow.global_key: global_key_2}, + {DataRow.row_data: sample_image, DataRow.global_key: global_key_3}, + ] + ) - assert {row.global_key for row in dataset.data_rows() - } == {global_key_1, global_key_2, global_key_3} + assert {row.global_key for row in dataset.data_rows()} == { + global_key_1, + global_key_2, + global_key_3, + } def test_data_row_bulk_creation_sync_with_same_global_keys( - dataset, sample_image): + dataset, sample_image +): global_key_1 = str(uuid.uuid4()) with pytest.raises(ResourceCreationError) as exc_info: - dataset.create_data_rows_sync([{ - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }, { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }]) + dataset.create_data_rows_sync( + [ + { + DataRow.row_data: sample_image, + DataRow.global_key: global_key_1, + }, + { + DataRow.row_data: sample_image, + DataRow.global_key: global_key_1, + }, + ] + ) assert len(list(dataset.data_rows())) == 1 assert list(dataset.data_rows())[0].global_key == global_key_1 @@ -1064,13 +1136,13 @@ def test_data_row_bulk_creation_sync_with_same_global_keys( def conversational_data_rows(dataset, conversational_content): examples = [ { - **conversational_content, 'media_type': - MediaType.Conversational.value + **conversational_content, + "media_type": MediaType.Conversational.value, }, conversational_content, { - "conversationalData": conversational_content['row_data']['messages'] - } # Old way to check for backwards compatibility + "conversationalData": conversational_content["row_data"]["messages"] + }, # Old way to check for backwards compatibility ] task = dataset.create_data_rows(examples) task.wait_till_done() @@ -1083,49 +1155,47 @@ def conversational_data_rows(dataset, 
conversational_content): dr.delete() -def test_create_conversational_text(conversational_data_rows, - conversational_content): +def test_create_conversational_text( + conversational_data_rows, conversational_content +): data_rows = conversational_data_rows for data_row in data_rows: - assert json.loads( - data_row.row_data) == conversational_content['row_data'] + assert ( + json.loads(data_row.row_data) == conversational_content["row_data"] + ) def test_invalid_media_type(dataset, conversational_content): - for _, __ in [["Found invalid contents for media type: 'IMAGE'", 'IMAGE'], - [ - "Found invalid media type: 'totallyinvalid'", - 'totallyinvalid' - ]]: + for _, __ in [ + ["Found invalid contents for media type: 'IMAGE'", "IMAGE"], + ["Found invalid media type: 'totallyinvalid'", "totallyinvalid"], + ]: # TODO: What error kind should this be? It looks like for global key we are # using malformed query. But for invalid contents in FileUploads we use InvalidQueryError with pytest.raises(ResourceCreationError): - dataset.create_data_rows_sync([{ - **conversational_content, 'media_type': 'IMAGE' - }]) + dataset.create_data_rows_sync( + [{**conversational_content, "media_type": "IMAGE"}] + ) def test_create_tiled_layer(dataset, tile_content): examples = [ - { - **tile_content, 'media_type': 'TMS_GEO' - }, + {**tile_content, "media_type": "TMS_GEO"}, tile_content, ] dataset.create_data_rows_sync(examples) data_rows = list(dataset.data_rows()) assert len(data_rows) == len(examples) for data_row in data_rows: - assert json.loads(data_row.row_data) == tile_content['row_data'] + assert json.loads(data_row.row_data) == tile_content["row_data"] def test_create_data_row_with_attachments(dataset): - attachment_value = 'attachment value' - dr = dataset.create_data_row(row_data="123", - attachments=[{ - 'type': 'RAW_TEXT', - 'value': attachment_value - }]) + attachment_value = "attachment value" + dr = dataset.create_data_row( + row_data="123", + attachments=[{"type": "RAW_TEXT", "value": attachment_value}], + ) attachments = list(dr.attachments()) assert len(attachments) == 1 @@ -1133,7 +1203,8 @@ def test_create_data_row_with_attachments(dataset): def test_create_data_row_with_media_type(dataset, image_url): with pytest.raises(ResourceCreationError) as exc: dr = dataset.create_data_row( - row_data={'invalid_object': 'invalid_value'}, media_type="IMAGE") + row_data={"invalid_object": "invalid_value"}, media_type="IMAGE" + ) assert "Expected type image/*, detected: application/json" in str(exc.value) diff --git a/libs/labelbox/tests/integration/test_data_rows_upsert.py b/libs/labelbox/tests/integration/test_data_rows_upsert.py index da99eecc6..2ba7a9df9 100644 --- a/libs/labelbox/tests/integration/test_data_rows_upsert.py +++ b/libs/labelbox/tests/integration/test_data_rows_upsert.py @@ -9,87 +9,70 @@ class TestDataRowUpsert: - @pytest.fixture def all_inclusive_data_row(self, dataset, image_url): dr = dataset.create_data_row( row_data=image_url, external_id="ex1", global_key=str(uuid.uuid4()), - metadata_fields=[{ - "name": "tag", - "value": "tag_string" - }, { - "name": "split", - "value": "train" - }], + metadata_fields=[ + {"name": "tag", "value": "tag_string"}, + {"name": "split", "value": "train"}, + ], attachments=[ + {"type": "RAW_TEXT", "name": "att1", "value": "test1"}, { - "type": "RAW_TEXT", - "name": "att1", - "value": "test1" - }, - { - "type": - "IMAGE", - "name": - "att2", - "value": - "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg" + "type": 
"IMAGE", + "name": "att2", + "value": "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg", }, { - "type": - "PDF_URL", - "name": - "att3", - "value": - "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" + "type": "PDF_URL", + "name": "att3", + "value": "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf", }, - ]) + ], + ) return dr @pytest.mark.order(1) def test_create_data_row_with_auto_key(self, dataset, image_url): - task = dataset.upsert_data_rows([{'row_data': image_url}]) + task = dataset.upsert_data_rows([{"row_data": image_url}]) task.wait_till_done() assert len(list(dataset.data_rows())) == 1 def test_create_data_row_with_upsert(self, client, dataset, image_url): gkey = str(uuid.uuid4()) - task = dataset.upsert_data_rows([{ - 'row_data': - image_url, - 'global_key': - gkey, - 'external_id': - "ex1", - 'attachments': [{ - 'type': AttachmentType.RAW_TEXT, - 'name': "att1", - 'value': "test1" - }, { - 'type': - AttachmentType.IMAGE, - 'name': - "att2", - 'value': - "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg" - }, { - 'type': - AttachmentType.PDF_URL, - 'name': - "att3", - 'value': - "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" - }], - 'metadata': [{ - 'name': "tag", - 'value': "updated tag" - }, { - 'name': "split", - 'value': "train" - }] - }]) + task = dataset.upsert_data_rows( + [ + { + "row_data": image_url, + "global_key": gkey, + "external_id": "ex1", + "attachments": [ + { + "type": AttachmentType.RAW_TEXT, + "name": "att1", + "value": "test1", + }, + { + "type": AttachmentType.IMAGE, + "name": "att2", + "value": "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg", + }, + { + "type": AttachmentType.PDF_URL, + "name": "att3", + "value": "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf", + }, + ], + "metadata": [ + {"name": "tag", "value": "updated tag"}, + {"name": "split", "value": "train"}, + ], + } + ] + ) task.wait_till_done() assert task.status == "COMPLETE" dr = client.get_data_row_by_global_key(gkey) @@ -107,31 +90,40 @@ def test_create_data_row_with_upsert(self, client, dataset, image_url): assert attachments[1].attachment_name == "att2" assert attachments[1].attachment_type == AttachmentType.IMAGE - assert attachments[ - 1].attachment_value == "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg" + assert ( + attachments[1].attachment_value + == "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg" + ) assert attachments[2].attachment_name == "att3" assert attachments[2].attachment_type == AttachmentType.PDF_URL - assert attachments[ - 2].attachment_value == "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" + assert ( + attachments[2].attachment_value + == "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" + ) assert len(dr.metadata_fields) == 2 - assert dr.metadata_fields[0]['name'] == "tag" - assert dr.metadata_fields[0]['value'] == "updated tag" - assert dr.metadata_fields[1]['name'] == "split" - assert dr.metadata_fields[1]['value'] == "train" + assert dr.metadata_fields[0]["name"] == "tag" + assert dr.metadata_fields[0]["value"] == "updated tag" + assert dr.metadata_fields[1]["name"] == "split" + 
assert dr.metadata_fields[1]["value"] == "train" - def test_update_data_row_fields_with_upsert(self, client, dataset, - image_url): + def test_update_data_row_fields_with_upsert( + self, client, dataset, image_url + ): gkey = str(uuid.uuid4()) - dr = dataset.create_data_row(row_data=image_url, - external_id="ex1", - global_key=gkey) - task = dataset.upsert_data_rows([{ - 'key': UniqueId(dr.uid), - 'external_id': "ex1_updated", - 'global_key': f"{gkey}_updated" - }]) + dr = dataset.create_data_row( + row_data=image_url, external_id="ex1", global_key=gkey + ) + task = dataset.upsert_data_rows( + [ + { + "key": UniqueId(dr.uid), + "external_id": "ex1_updated", + "global_key": f"{gkey}_updated", + } + ] + ) task.wait_till_done() assert task.status == "COMPLETE" dr = client.get_data_row(dr.uid) @@ -140,16 +132,21 @@ def test_update_data_row_fields_with_upsert(self, client, dataset, assert dr.global_key == f"{gkey}_updated" def test_update_data_row_fields_with_upsert_by_global_key( - self, client, dataset, image_url): + self, client, dataset, image_url + ): gkey = str(uuid.uuid4()) - dr = dataset.create_data_row(row_data=image_url, - external_id="ex1", - global_key=gkey) - task = dataset.upsert_data_rows([{ - 'key': GlobalKey(dr.global_key), - 'external_id': "ex1_updated", - 'global_key': f"{gkey}_updated" - }]) + dr = dataset.create_data_row( + row_data=image_url, external_id="ex1", global_key=gkey + ) + task = dataset.upsert_data_rows( + [ + { + "key": GlobalKey(dr.global_key), + "external_id": "ex1_updated", + "global_key": f"{gkey}_updated", + } + ] + ) task.wait_till_done() assert task.status == "COMPLETE" dr = client.get_data_row(dr.uid) @@ -157,20 +154,25 @@ def test_update_data_row_fields_with_upsert_by_global_key( assert dr.external_id == "ex1_updated" assert dr.global_key == f"{gkey}_updated" - def test_update_attachments_with_upsert(self, client, - all_inclusive_data_row, dataset): + def test_update_attachments_with_upsert( + self, client, all_inclusive_data_row, dataset + ): dr = all_inclusive_data_row - task = dataset.upsert_data_rows([{ - 'key': - UniqueId(dr.uid), - 'row_data': - dr.row_data, - 'attachments': [{ - 'type': AttachmentType.RAW_TEXT, - 'name': "att1", - 'value': "test" - }] - }]) + task = dataset.upsert_data_rows( + [ + { + "key": UniqueId(dr.uid), + "row_data": dr.row_data, + "attachments": [ + { + "type": AttachmentType.RAW_TEXT, + "name": "att1", + "value": "test", + } + ], + } + ] + ) task.wait_till_done() assert task.status == "COMPLETE" dr = client.get_data_row(dr.uid) @@ -179,44 +181,49 @@ def test_update_attachments_with_upsert(self, client, assert len(attachments) == 1 assert attachments[0].attachment_name == "att1" - def test_update_metadata_with_upsert(self, client, all_inclusive_data_row, - dataset): + def test_update_metadata_with_upsert( + self, client, all_inclusive_data_row, dataset + ): dr = all_inclusive_data_row - task = dataset.upsert_data_rows([{ - 'key': - GlobalKey(dr.global_key), - 'row_data': - dr.row_data, - 'metadata': [{ - 'name': "tag", - 'value': "updated tag" - }, { - 'name': "split", - 'value': "train" - }] - }]) + task = dataset.upsert_data_rows( + [ + { + "key": GlobalKey(dr.global_key), + "row_data": dr.row_data, + "metadata": [ + {"name": "tag", "value": "updated tag"}, + {"name": "split", "value": "train"}, + ], + } + ] + ) task.wait_till_done() assert task.status == "COMPLETE" dr = client.get_data_row(dr.uid) assert dr is not None assert len(dr.metadata_fields) == 2 - assert dr.metadata_fields[0]['name'] == "tag" - assert 
dr.metadata_fields[0]['value'] == "updated tag" - assert dr.metadata_fields[1]['name'] == "split" - assert dr.metadata_fields[1]['value'] == "train" + assert dr.metadata_fields[0]["name"] == "tag" + assert dr.metadata_fields[0]["value"] == "updated tag" + assert dr.metadata_fields[1]["name"] == "split" + assert dr.metadata_fields[1]["value"] == "train" def test_multiple_chunks(self, client, dataset, image_url): mocked_chunk_size = 300 - with patch('labelbox.client.Client.upload_data', - wraps=client.upload_data) as spy_some_function: - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', - new=mocked_chunk_size): - task = dataset.upsert_data_rows([{ - 'row_data': image_url - } for i in range(10)]) + with patch( + "labelbox.client.Client.upload_data", wraps=client.upload_data + ) as spy_some_function: + with patch( + "labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES", + new=mocked_chunk_size, + ): + task = dataset.upsert_data_rows( + [{"row_data": image_url} for i in range(10)] + ) task.wait_till_done() assert len(list(dataset.data_rows())) == 10 - assert spy_some_function.call_count == 11 # one per each data row + manifest + assert ( + spy_some_function.call_count == 11 + ) # one per each data row + manifest first_call_args, _ = spy_some_function.call_args_list[0] first_chunk_content = first_call_args[0] @@ -228,23 +235,25 @@ def test_multiple_chunks(self, client, dataset, image_url): assert len(data) in {1, 3} last_call_args, _ = spy_some_function.call_args_list[-1] - manifest_content = last_call_args[0].decode('utf-8') + manifest_content = last_call_args[0].decode("utf-8") data = json.loads(manifest_content) - assert data['source'] == "SDK" - assert data['item_count'] == 10 - assert len(data['chunk_uris']) == 10 + assert data["source"] == "SDK" + assert data["item_count"] == 10 + assert len(data["chunk_uris"]) == 10 def test_upsert_embedded_row_data(self, dataset): pdf_url = "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/0801.3483.pdf" - task = dataset.upsert_data_rows([{ - 'row_data': { - "pdf_url": - pdf_url, - "text_layer_url": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/0801.3483-lb-textlayer.json" - }, - 'media_type': "PDF" - }]) + task = dataset.upsert_data_rows( + [ + { + "row_data": { + "pdf_url": pdf_url, + "text_layer_url": "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/0801.3483-lb-textlayer.json", + }, + "media_type": "PDF", + } + ] + ) task.wait_till_done() data_rows = list(dataset.data_rows()) assert len(data_rows) == 1 @@ -252,21 +261,17 @@ def test_upsert_embedded_row_data(self, dataset): def test_upsert_duplicate_global_key_error(self, dataset, image_url): gkey = str(uuid.uuid4()) - task = dataset.upsert_data_rows([ - { - 'row_data': image_url, - 'global_key': gkey - }, - { - 'row_data': image_url, - 'global_key': gkey - }, - ]) + task = dataset.upsert_data_rows( + [ + {"row_data": image_url, "global_key": gkey}, + {"row_data": image_url, "global_key": gkey}, + ] + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is not None assert len(task.errors) == 1 # one data row was created, one failed - assert f"Duplicate global key: '{gkey}'" in task.errors[0]['message'] + assert f"Duplicate global key: '{gkey}'" in task.errors[0]["message"] def test_upsert_empty_items(self, dataset): items = [{"key": GlobalKey("foo")}] diff --git a/libs/labelbox/tests/integration/test_dataset.py b/libs/labelbox/tests/integration/test_dataset.py index 51a43a09c..89210d6c9 100644 --- 
a/libs/labelbox/tests/integration/test_dataset.py +++ b/libs/labelbox/tests/integration/test_dataset.py @@ -4,11 +4,12 @@ from labelbox import Dataset from labelbox.exceptions import ResourceNotFoundError, ResourceCreationError -from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator +from labelbox.schema.internal.descriptor_file_creator import ( + DescriptorFileCreator, +) def test_dataset(client, rand_gen): - # confirm dataset can be created name = rand_gen(str) dataset = client.create_dataset(name=name) @@ -76,8 +77,9 @@ def test_get_data_row_for_external_id(dataset, rand_gen, image_url): with pytest.raises(ResourceNotFoundError): data_row = dataset.data_row_for_external_id(external_id) - data_row = dataset.create_data_row(row_data=image_url, - external_id=external_id) + data_row = dataset.create_data_row( + row_data=image_url, external_id=external_id + ) found = dataset.data_row_for_external_id(external_id) assert found.uid == data_row.uid @@ -87,7 +89,8 @@ def test_get_data_row_for_external_id(dataset, rand_gen, image_url): assert len(dataset.data_rows_for_external_id(external_id)) == 2 task = dataset.create_data_rows( - [dict(row_data=image_url, external_id=external_id)]) + [dict(row_data=image_url, external_id=external_id)] + ) task.wait_till_done() assert len(dataset.data_rows_for_external_id(external_id)) == 3 @@ -102,41 +105,40 @@ def test_upload_video_file(dataset, sample_video: str) -> None: task = dataset.create_data_rows([sample_video, sample_video]) task.wait_till_done() - with open(sample_video, 'rb') as video_f: + with open(sample_video, "rb") as video_f: content_length = len(video_f.read()) for data_row in dataset.data_rows(): url = data_row.row_data response = requests.head(url, allow_redirects=True) - assert int(response.headers['Content-Length']) == content_length - assert response.headers['Content-Type'] == 'video/mp4' + assert int(response.headers["Content-Length"]) == content_length + assert response.headers["Content-Type"] == "video/mp4" def test_create_pdf(dataset): dataset.create_data_row( row_data={ - "pdfUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", - "textLayerUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json" - }) - dataset.create_data_row(row_data={ - "pdfUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", - "textLayerUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json" - }, - media_type="PDF") + "pdfUrl": "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", + "textLayerUrl": "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json", + } + ) + dataset.create_data_row( + row_data={ + "pdfUrl": "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", + "textLayerUrl": "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json", + }, + media_type="PDF", + ) with pytest.raises(ResourceCreationError): # Wrong media type - dataset.create_data_row(row_data={ - "pdfUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", - "textLayerUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json" - }, - media_type="TEXT") + dataset.create_data_row( + row_data={ + "pdfUrl": 
"https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", + "textLayerUrl": "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json", + }, + media_type="TEXT", + ) def test_bulk_conversation(dataset, sample_bulk_conversation: list) -> None: @@ -152,17 +154,21 @@ def test_bulk_conversation(dataset, sample_bulk_conversation: list) -> None: def test_create_descriptor_file(dataset): import unittest.mock as mock + client = MagicMock() - with mock.patch.object(client, 'upload_data', - wraps=client.upload_data) as upload_data_spy: - DescriptorFileCreator(client).create_one(items=[{ - 'row_data': 'some text...' - }]) + with mock.patch.object( + client, "upload_data", wraps=client.upload_data + ) as upload_data_spy: + DescriptorFileCreator(client).create_one( + items=[{"row_data": "some text..."}] + ) upload_data_spy.assert_called() - call_args, call_kwargs = upload_data_spy.call_args_list[0][ - 0], upload_data_spy.call_args_list[0][1] + call_args, call_kwargs = ( + upload_data_spy.call_args_list[0][0], + upload_data_spy.call_args_list[0][1], + ) assert call_args == ('[{"row_data": "some text..."}]',) assert call_kwargs == { - 'content_type': 'application/json', - 'filename': 'json_import.json' + "content_type": "application/json", + "filename": "json_import.json", } diff --git a/libs/labelbox/tests/integration/test_delegated_access.py b/libs/labelbox/tests/integration/test_delegated_access.py index 1592319d2..0e6422b08 100644 --- a/libs/labelbox/tests/integration/test_delegated_access.py +++ b/libs/labelbox/tests/integration/test_delegated_access.py @@ -8,37 +8,39 @@ @pytest.mark.skip( - reason= - "Google credentials are being updated for this test, disabling till it's all sorted out" + reason="Google credentials are being updated for this test, disabling till it's all sorted out" +) +@pytest.mark.skipif( + not os.environ.get("DA_GCP_LABELBOX_API_KEY"), + reason="DA_GCP_LABELBOX_API_KEY not found", ) -@pytest.mark.skipif(not os.environ.get('DA_GCP_LABELBOX_API_KEY'), - reason="DA_GCP_LABELBOX_API_KEY not found") def test_default_integration(): """ This tests assumes the following: 1. gcp delegated access is configured to work with jtso-gcs-sdk-da-tests 2. the integration name is gcs sdk test bucket 3. 
This integration is the default - + Currently tests against: Org ID: cl269lvvj78b50zau34s4550z Email: jtso+gcp_sdk_tests@labelbox.com""" client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY")) ds = client.create_dataset(name="new_ds") dr = ds.create_data_row( - row_data= - "gs://jtso-gcs-sdk-da-tests/nikita-samokhin-D6QS6iv_CTY-unsplash.jpg") + row_data="gs://jtso-gcs-sdk-da-tests/nikita-samokhin-D6QS6iv_CTY-unsplash.jpg" + ) assert requests.get(dr.row_data).status_code == 200 assert ds.iam_integration().name == "gcs sdk test bucket" ds.delete() @pytest.mark.skip( - reason= - "Google credentials are being updated for this test, disabling till it's all sorted out" + reason="Google credentials are being updated for this test, disabling till it's all sorted out" +) +@pytest.mark.skipif( + not os.environ.get("DA_GCP_LABELBOX_API_KEY"), + reason="DA_GCP_LABELBOX_API_KEY not found", ) -@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"), - reason="DA_GCP_LABELBOX_API_KEY not found") def test_non_default_integration(): """ This tests assumes the following: @@ -52,14 +54,13 @@ def test_non_default_integration(): client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY")) integrations = client.get_organization().get_iam_integrations() integration = [ - inte for inte in integrations if 'aws-da-test-bucket' in inte.name + inte for inte in integrations if "aws-da-test-bucket" in inte.name ][0] assert integration.valid ds = client.create_dataset(iam_integration=integration, name="new_ds") assert ds.iam_integration().name == "aws-da-test-bucket" dr = ds.create_data_row( - row_data= - "https://jtso-aws-da-sdk-tests.s3.us-east-2.amazonaws.com/adrian-yu-qkN4D3Rf1gw-unsplash.jpg" + row_data="https://jtso-aws-da-sdk-tests.s3.us-east-2.amazonaws.com/adrian-yu-qkN4D3Rf1gw-unsplash.jpg" ) assert requests.get(dr.row_data).status_code == 200 ds.delete() @@ -81,15 +82,16 @@ def test_no_default_integration(client): @pytest.mark.skip( - reason= - "Google credentials are being updated for this test, disabling till it's all sorted out" + reason="Google credentials are being updated for this test, disabling till it's all sorted out" +) +@pytest.mark.skipif( + not os.environ.get("DA_GCP_LABELBOX_API_KEY"), + reason="DA_GCP_LABELBOX_API_KEY not found", ) -@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"), - reason="DA_GCP_LABELBOX_API_KEY not found") def test_add_integration_from_object(): """ This test is based on test_non_default_integration() and assumes the following: - + 1. aws delegated access is configured to work with lbox-test-bucket 2. 
an integration called aws is available to the org @@ -102,11 +104,14 @@ def test_add_integration_from_object(): # Prepare dataset with no integration integration = [ - integration for integration - in integrations - if 'aws-da-test-bucket' in integration.name][0] + integration + for integration in integrations + if "aws-da-test-bucket" in integration.name + ][0] - ds = client.create_dataset(iam_integration=None, name=f"integration_add_obj-{uuid.uuid4()}") + ds = client.create_dataset( + iam_integration=None, name=f"integration_add_obj-{uuid.uuid4()}" + ) # Test set integration with object new_integration = ds.add_iam_integration(integration) @@ -115,16 +120,18 @@ def test_add_integration_from_object(): # Cleaning ds.delete() + @pytest.mark.skip( - reason= - "Google credentials are being updated for this test, disabling till it's all sorted out" + reason="Google credentials are being updated for this test, disabling till it's all sorted out" +) +@pytest.mark.skipif( + not os.environ.get("DA_GCP_LABELBOX_API_KEY"), + reason="DA_GCP_LABELBOX_API_KEY not found", ) -@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"), - reason="DA_GCP_LABELBOX_API_KEY not found") def test_add_integration_from_uid(): """ This test is based on test_non_default_integration() and assumes the following: - + 1. aws delegated access is configured to work with lbox-test-bucket 2. an integration called aws is available to the org @@ -137,34 +144,40 @@ def test_add_integration_from_uid(): # Prepare dataset with no integration integration = [ - integration for integration - in integrations - if 'aws-da-test-bucket' in integration.name][0] + integration + for integration in integrations + if "aws-da-test-bucket" in integration.name + ][0] - ds = client.create_dataset(iam_integration=None, name=f"integration_add_id-{uuid.uuid4()}") + ds = client.create_dataset( + iam_integration=None, name=f"integration_add_id-{uuid.uuid4()}" + ) # Test set integration with integration id integration_id = [ - integration.uid for integration - in integrations - if 'aws-da-test-bucket' in integration.name][0] - + integration.uid + for integration in integrations + if "aws-da-test-bucket" in integration.name + ][0] + new_integration = ds.add_iam_integration(integration_id) assert new_integration == integration # Cleaning ds.delete() + @pytest.mark.skip( - reason= - "Google credentials are being updated for this test, disabling till it's all sorted out" + reason="Google credentials are being updated for this test, disabling till it's all sorted out" +) +@pytest.mark.skipif( + not os.environ.get("DA_GCP_LABELBOX_API_KEY"), + reason="DA_GCP_LABELBOX_API_KEY not found", ) -@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"), - reason="DA_GCP_LABELBOX_API_KEY not found") def test_integration_remove(): """ This test is based on test_non_default_integration() and assumes the following: - + 1. aws delegated access is configured to work with lbox-test-bucket 2. 
an integration called aws is available to the org @@ -177,15 +190,18 @@ def test_integration_remove(): # Prepare dataset with an existing integration integration = [ - integration for integration - in integrations - if 'aws-da-test-bucket' in integration.name][0] + integration + for integration in integrations + if "aws-da-test-bucket" in integration.name + ][0] - ds = client.create_dataset(iam_integration=integration, name=f"integration_remove-{uuid.uuid4()}") + ds = client.create_dataset( + iam_integration=integration, name=f"integration_remove-{uuid.uuid4()}" + ) # Test unset integration ds.remove_iam_integration() assert ds.iam_integration() is None # Cleaning - ds.delete() \ No newline at end of file + ds.delete() diff --git a/libs/labelbox/tests/integration/test_embedding.py b/libs/labelbox/tests/integration/test_embedding.py index 541b6d980..1b54ab81c 100644 --- a/libs/labelbox/tests/integration/test_embedding.py +++ b/libs/labelbox/tests/integration/test_embedding.py @@ -27,9 +27,10 @@ def test_get_embedding_by_name_not_found(client: Client): client.get_embedding_by_name("does-not-exist") -@pytest.mark.parametrize('data_rows', [10], indirect=True) -def test_import_vectors_from_file(data_rows: List[DataRow], - embedding: Embedding): +@pytest.mark.parametrize("data_rows", [10], indirect=True) +def test_import_vectors_from_file( + data_rows: List[DataRow], embedding: Embedding +): vector = [random.uniform(1.0, 2.0) for _ in range(embedding.dims)] event = threading.Event() @@ -38,10 +39,7 @@ def callback(_: Dict[str, Any]): with NamedTemporaryFile(mode="w+") as fp: lines = [ - json.dumps({ - "id": dr.uid, - "vector": vector - }) for dr in data_rows + json.dumps({"id": dr.uid, "vector": vector}) for dr in data_rows ] fp.writelines(lines) fp.flush() @@ -54,10 +52,9 @@ def test_get_imported_vector_count(dataset: Dataset, embedding: Embedding): assert embedding.get_imported_vector_count() == 0 vector = [random.uniform(1.0, 2.0) for _ in range(embedding.dims)] - dataset.create_data_row(row_data="foo", - embeddings=[{ - "embedding_id": embedding.id, - "vector": vector - }]) + dataset.create_data_row( + row_data="foo", + embeddings=[{"embedding_id": embedding.id, "vector": vector}], + ) assert embedding.get_imported_vector_count() == 1 diff --git a/libs/labelbox/tests/integration/test_ephemeral.py b/libs/labelbox/tests/integration/test_ephemeral.py index 6ebcf61c6..a23572fdf 100644 --- a/libs/labelbox/tests/integration/test_ephemeral.py +++ b/libs/labelbox/tests/integration/test_ephemeral.py @@ -2,8 +2,10 @@ import pytest -@pytest.mark.skipif(not os.environ.get('LABELBOX_TEST_ENVIRON') == 'ephemeral', - reason='This test only runs in EPHEMERAL environment') +@pytest.mark.skipif( + not os.environ.get("LABELBOX_TEST_ENVIRON") == "ephemeral", + reason="This test only runs in EPHEMERAL environment", +) def test_org_and_user_setup(client, ephmeral_client): assert type(client) == ephmeral_client assert client.admin_client @@ -15,7 +17,9 @@ def test_org_and_user_setup(client, ephmeral_client): assert user -@pytest.mark.skipif(os.environ.get('LABELBOX_TEST_ENVIRON') == 'ephemeral', - reason='This test does not run in EPHEMERAL environment') +@pytest.mark.skipif( + os.environ.get("LABELBOX_TEST_ENVIRON") == "ephemeral", + reason="This test does not run in EPHEMERAL environment", +) def test_integration_client(client, integration_client): assert type(client) == integration_client diff --git a/libs/labelbox/tests/integration/test_feature_schema.py b/libs/labelbox/tests/integration/test_feature_schema.py 
index 1dc25efc1..1dc940f08 100644 --- a/libs/labelbox/tests/integration/test_feature_schema.py +++ b/libs/labelbox/tests/integration/test_feature_schema.py @@ -12,36 +12,37 @@ def test_deletes_a_feature_schema(client): tool = client.upsert_feature_schema(point.asdict()) - assert client.delete_unused_feature_schema( - tool.normalized['featureSchemaId']) is None + assert ( + client.delete_unused_feature_schema(tool.normalized["featureSchemaId"]) + is None + ) def test_cant_delete_already_deleted_feature_schema(client): tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] client.delete_unused_feature_schema(feature_schema_id) is None with pytest.raises( - Exception, - match= - "Failed to delete the feature schema, message: Feature schema is already deleted" + Exception, + match="Failed to delete the feature schema, message: Feature schema is already deleted", ): client.delete_unused_feature_schema(feature_schema_id) def test_cant_delete_feature_schema_with_ontology(client): tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] ontology = client.create_ontology_from_feature_schemas( - name='ontology name', + name="ontology name", feature_schema_ids=[feature_schema_id], - media_type=MediaType.Image) + media_type=MediaType.Image, + ) with pytest.raises( - Exception, - match= - "Failed to delete the feature schema, message: Feature schema cannot be deleted because it is used in ontologies" + Exception, + match="Failed to delete the feature schema, message: Feature schema cannot be deleted because it is used in ontologies", ): client.delete_unused_feature_schema(feature_schema_id) @@ -51,29 +52,30 @@ def test_cant_delete_feature_schema_with_ontology(client): def test_throws_an_error_if_feature_schema_to_delete_doesnt_exist(client): with pytest.raises( - Exception, - match= - "Failed to delete the feature schema, message: Cannot find root schema node with feature schema id doesntexist" + Exception, + match="Failed to delete the feature schema, message: Cannot find root schema node with feature schema id doesntexist", ): client.delete_unused_feature_schema("doesntexist") def test_updates_a_feature_schema_title(client): tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] new_title = "new title" updated_feature_schema = client.update_feature_schema_title( - feature_schema_id, new_title) + feature_schema_id, new_title + ) - assert updated_feature_schema.normalized['name'] == new_title + assert updated_feature_schema.normalized["name"] == new_title client.delete_unused_feature_schema(feature_schema_id) def test_throws_an_error_when_updating_a_feature_schema_with_empty_title( - client): + client, +): tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] with pytest.raises(Exception): client.update_feature_schema_title(feature_schema_id, "") @@ -96,21 +98,23 @@ def test_updates_a_feature_schema(client, feature_schema): tool=Tool.Type.POINT, name="new name", color="#ff0000", - feature_schema_id=created_feature_schema.normalized['featureSchemaId'], + feature_schema_id=created_feature_schema.normalized["featureSchemaId"], ) updated_feature_schema = 
client.upsert_feature_schema( - tool_to_update.asdict()) + tool_to_update.asdict() + ) - assert updated_feature_schema.normalized['name'] == "new name" + assert updated_feature_schema.normalized["name"] == "new name" def test_does_not_include_used_feature_schema(client): tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] ontology = client.create_ontology_from_feature_schemas( - name='ontology name', + name="ontology name", feature_schema_ids=[feature_schema_id], - media_type=MediaType.Image) + media_type=MediaType.Image, + ) unused_feature_schemas = client.get_unused_feature_schemas() assert feature_schema_id not in unused_feature_schemas diff --git a/libs/labelbox/tests/integration/test_filtering.py b/libs/labelbox/tests/integration/test_filtering.py index e751213cc..2e09ba573 100644 --- a/libs/labelbox/tests/integration/test_filtering.py +++ b/libs/labelbox/tests/integration/test_filtering.py @@ -30,8 +30,9 @@ def test_where(client, project_to_test_where): p_b_name = p_b.name def get(where=None): - date_where = Project.created_at >= min(p_a.created_at, p_b.created_at, - p_c.created_at) + date_where = Project.created_at >= min( + p_a.created_at, p_b.created_at, p_c.created_at + ) where = date_where if where is None else where & date_where return {p.uid for p in client.get_projects(where)} @@ -47,14 +48,16 @@ def get(where=None): ge_b = get(Project.name >= p_b_name) assert {p_b.uid, p_c.uid}.issubset(ge_b) and p_a.uid not in ge_b + def test_unsupported_where(client): with pytest.raises(InvalidQueryError): client.get_projects(where=(Project.name == "a") & (Project.name == "b")) # TODO support logical OR and NOT in where with pytest.raises(InvalidQueryError): - client.get_projects(where=(Project.name == "a") | - (Project.description == "b")) + client.get_projects( + where=(Project.name == "a") | (Project.description == "b") + ) with pytest.raises(InvalidQueryError): client.get_projects(where=~(Project.name == "a")) diff --git a/libs/labelbox/tests/integration/test_foundry.py b/libs/labelbox/tests/integration/test_foundry.py index 10d6be85b..83c4effc5 100644 --- a/libs/labelbox/tests/integration/test_foundry.py +++ b/libs/labelbox/tests/integration/test_foundry.py @@ -21,14 +21,15 @@ def foundry_client(client): @pytest.fixture() def text_data_row(dataset, random_str): global_key = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-1.txt-{random_str}" - task = dataset.create_data_rows([{ - "row_data": - "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-1.txt", - "media_type": - "TEXT", - "global_key": - global_key - }]) + task = dataset.create_data_rows( + [ + { + "row_data": "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-1.txt", + "media_type": "TEXT", + "global_key": global_key, + } + ] + ) task.wait_till_done() dr = dataset.data_rows().get_one() yield dr @@ -38,32 +39,40 @@ def text_data_row(dataset, random_str): @pytest.fixture() def ontology(client, random_str): object_features = [ - lb.Tool(tool=lb.Tool.Type.BBOX, - name="text", - color="#ff0000", - classifications=[ - lb.Classification(class_type=lb.Classification.Type.TEXT, - name="value") - ]) + lb.Tool( + tool=lb.Tool.Type.BBOX, + name="text", + color="#ff0000", + classifications=[ + lb.Classification( + class_type=lb.Classification.Type.TEXT, name="value" + ) + ], + ) ] - ontology_builder = 
lb.OntologyBuilder(tools=object_features,) + ontology_builder = lb.OntologyBuilder( + tools=object_features, + ) ontology = client.create_ontology( f"Test ontology for tesseract model {random_str}", ontology_builder.asdict(), - media_type=lb.MediaType.Image) + media_type=lb.MediaType.Image, + ) return ontology @pytest.fixture() def unsaved_app(random_str, ontology): - return App(model_id=TEST_MODEL_ID, - name=f"Test App {random_str}", - description="Test App Description", - inference_params={"confidence": 0.2}, - class_to_schema_id={}, - ontology_id=ontology.uid) + return App( + model_id=TEST_MODEL_ID, + name=f"Test App {random_str}", + description="Test App Description", + inference_params={"confidence": 0.2}, + class_to_schema_id={}, + ontology_id=ontology.uid, + ) @pytest.fixture() @@ -75,15 +84,15 @@ def app(foundry_client, unsaved_app): def test_create_app(foundry_client, unsaved_app): app = foundry_client._create_app(unsaved_app) - retrieved_dict = app.model_dump(exclude={'id', 'created_by'}) - expected_dict = app.model_dump(exclude={'id', 'created_by'}) + retrieved_dict = app.model_dump(exclude={"id", "created_by"}) + expected_dict = app.model_dump(exclude={"id", "created_by"}) assert retrieved_dict == expected_dict def test_get_app(foundry_client, app): retrieved_app = foundry_client._get_app(app.id) - retrieved_dict = retrieved_app.model_dump(exclude={'created_by'}) - expected_dict = app.model_dump(exclude={'created_by'}) + retrieved_dict = retrieved_app.model_dump(exclude={"created_by"}) + expected_dict = app.model_dump(exclude={"created_by"}) assert retrieved_dict == expected_dict @@ -92,57 +101,65 @@ def test_get_app_with_invalid_id(foundry_client): foundry_client._get_app("invalid-id") -def test_run_foundry_app_with_data_row_id(foundry_client, data_row, app, - random_str): +def test_run_foundry_app_with_data_row_id( + foundry_client, data_row, app, random_str +): data_rows = lb.DataRowIds([data_row.uid]) task = foundry_client.run_app( model_run_name=f"test-app-with-datarow-id-{random_str}", data_rows=data_rows, - app_id=app.id) + app_id=app.id, + ) task.wait_till_done() - assert task.status == 'COMPLETE' + assert task.status == "COMPLETE" -def test_run_foundry_app_with_global_key(foundry_client, data_row, app, - random_str): +def test_run_foundry_app_with_global_key( + foundry_client, data_row, app, random_str +): data_rows = lb.GlobalKeys([data_row.global_key]) task = foundry_client.run_app( model_run_name=f"test-app-with-global-key-{random_str}", data_rows=data_rows, - app_id=app.id) + app_id=app.id, + ) task.wait_till_done() - assert task.status == 'COMPLETE' + assert task.status == "COMPLETE" -def test_run_foundry_app_returns_model_run_id(foundry_client, data_row, app, - random_str): +def test_run_foundry_app_returns_model_run_id( + foundry_client, data_row, app, random_str +): data_rows = lb.GlobalKeys([data_row.global_key]) task = foundry_client.run_app( model_run_name=f"test-app-with-global-key-{random_str}", data_rows=data_rows, - app_id=app.id) - model_run_id = task.metadata['modelRunId'] + app_id=app.id, + ) + model_run_id = task.metadata["modelRunId"] model_run = foundry_client.client.get_model_run(model_run_id) assert model_run.uid == model_run_id def test_run_foundry_with_invalid_data_row_id(foundry_client, app, random_str): - invalid_datarow_id = 'invalid-global-key' + invalid_datarow_id = "invalid-global-key" data_rows = lb.GlobalKeys([invalid_datarow_id]) with pytest.raises(lb.exceptions.LabelboxError) as exception: foundry_client.run_app( 
model_run_name=f"test-app-with-invalid-datarow-id-{random_str}", data_rows=data_rows, - app_id=app.id) + app_id=app.id, + ) assert invalid_datarow_id in exception.value def test_run_foundry_with_invalid_global_key(foundry_client, app, random_str): - invalid_global_key = 'invalid-global-key' + invalid_global_key = "invalid-global-key" data_rows = lb.GlobalKeys([invalid_global_key]) with pytest.raises(lb.exceptions.LabelboxError) as exception: foundry_client.run_app( model_run_name=f"test-app-with-invalid-global-key-{random_str}", data_rows=data_rows, - app_id=app.id) + app_id=app.id, + ) assert invalid_global_key in exception.value diff --git a/libs/labelbox/tests/integration/test_global_keys.py b/libs/labelbox/tests/integration/test_global_keys.py index 3fd3d84d9..9dc357812 100644 --- a/libs/labelbox/tests/integration/test_global_keys.py +++ b/libs/labelbox/tests/integration/test_global_keys.py @@ -14,38 +14,29 @@ def test_assign_global_keys_to_data_rows(client, dataset, image_url): gk_1 = str(uuid.uuid4()) gk_2 = str(uuid.uuid4()) - assignment_inputs = [{ - "data_row_id": dr_1.uid, - "global_key": gk_1 - }, { - "data_row_id": dr_2.uid, - "global_key": gk_2 - }] + assignment_inputs = [ + {"data_row_id": dr_1.uid, "global_key": gk_1}, + {"data_row_id": dr_2.uid, "global_key": gk_2}, + ] res = client.assign_global_keys_to_data_rows(assignment_inputs) - assert res['status'] == "SUCCESS" - assert res['errors'] == [] + assert res["status"] == "SUCCESS" + assert res["errors"] == [] - assert len(res['results']) == 2 - for r in res['results']: - del r['sanitized'] - assert res['results'] == assignment_inputs + assert len(res["results"]) == 2 + for r in res["results"]: + del r["sanitized"] + assert res["results"] == assignment_inputs def test_assign_global_keys_to_data_rows_validation_error(client): - assignment_inputs = [{ - "data_row_id": "test uid", - "wrong_key": "gk 1" - }, { - "data_row_id": "test uid 2", - "global_key": "gk 2" - }, { - "wrong_key": "test uid 3", - "global_key": "gk 3" - }, { - "data_row_id": "test uid 4" - }, { - "global_key": "gk 5" - }, {}] + assignment_inputs = [ + {"data_row_id": "test uid", "wrong_key": "gk 1"}, + {"data_row_id": "test uid 2", "global_key": "gk 2"}, + {"wrong_key": "test uid 3", "global_key": "gk 3"}, + {"data_row_id": "test uid 4"}, + {"global_key": "gk 5"}, + {}, + ] with pytest.raises(ValueError) as excinfo: client.assign_global_keys_to_data_rows(assignment_inputs) e = """[{'data_row_id': 'test uid', 'wrong_key': 'gk 1'}, {'wrong_key': 'test uid 3', 'global_key': 'gk 3'}, {'data_row_id': 'test uid 4'}, {'global_key': 'gk 5'}, {}]""" @@ -58,124 +49,123 @@ def test_assign_same_global_keys_to_data_rows(client, dataset, image_url): gk_1 = str(uuid.uuid4()) - assignment_inputs = [{ - "data_row_id": dr_1.uid, - "global_key": gk_1 - }, { - "data_row_id": dr_2.uid, - "global_key": gk_1 - }] + assignment_inputs = [ + {"data_row_id": dr_1.uid, "global_key": gk_1}, + {"data_row_id": dr_2.uid, "global_key": gk_1}, + ] res = client.assign_global_keys_to_data_rows(assignment_inputs) - assert res['status'] == "PARTIAL SUCCESS" - assert len(res['results']) == 1 - assert res['results'][0]['data_row_id'] == dr_1.uid - assert res['results'][0]['global_key'] == gk_1 + assert res["status"] == "PARTIAL SUCCESS" + assert len(res["results"]) == 1 + assert res["results"][0]["data_row_id"] == dr_1.uid + assert res["results"][0]["global_key"] == gk_1 - assert len(res['errors']) == 1 - assert res['errors'][0]['data_row_id'] == dr_2.uid - assert res['errors'][0]['global_key'] == 
gk_1 - assert res['errors'][0][ - 'error'] == "Invalid assignment. Either DataRow does not exist, or globalKey is invalid" + assert len(res["errors"]) == 1 + assert res["errors"][0]["data_row_id"] == dr_2.uid + assert res["errors"][0]["global_key"] == gk_1 + assert ( + res["errors"][0]["error"] + == "Invalid assignment. Either DataRow does not exist, or globalKey is invalid" + ) def test_long_global_key_validation(client, dataset, image_url): - long_global_key = 'x' * 201 + long_global_key = "x" * 201 dr_1 = dataset.create_data_row(row_data=image_url) dr_2 = dataset.create_data_row(row_data=image_url) gk_1 = str(uuid.uuid4()) gk_2 = long_global_key - assignment_inputs = [{ - "data_row_id": dr_1.uid, - "global_key": gk_1 - }, { - "data_row_id": dr_2.uid, - "global_key": gk_2 - }] + assignment_inputs = [ + {"data_row_id": dr_1.uid, "global_key": gk_1}, + {"data_row_id": dr_2.uid, "global_key": gk_2}, + ] res = client.assign_global_keys_to_data_rows(assignment_inputs) - assert len(res['results']) == 1 - assert len(res['errors']) == 1 - assert res['status'] == 'PARTIAL SUCCESS' - assert res['results'][0]['data_row_id'] == dr_1.uid - assert res['results'][0]['global_key'] == gk_1 - assert res['errors'][0]['data_row_id'] == dr_2.uid - assert res['errors'][0]['global_key'] == gk_2 - assert res['errors'][0][ - 'error'] == 'Invalid assignment. Either DataRow does not exist, or globalKey is invalid' + assert len(res["results"]) == 1 + assert len(res["errors"]) == 1 + assert res["status"] == "PARTIAL SUCCESS" + assert res["results"][0]["data_row_id"] == dr_1.uid + assert res["results"][0]["global_key"] == gk_1 + assert res["errors"][0]["data_row_id"] == dr_2.uid + assert res["errors"][0]["global_key"] == gk_2 + assert ( + res["errors"][0]["error"] + == "Invalid assignment. 
Either DataRow does not exist, or globalKey is invalid" + ) def test_global_key_with_whitespaces_validation(client, dataset, image_url): - data_row_items = [{ - "row_data": image_url, - }, { - "row_data": image_url, - }, { - "row_data": image_url, - }] + data_row_items = [ + { + "row_data": image_url, + }, + { + "row_data": image_url, + }, + { + "row_data": image_url, + }, + ] task = dataset.create_data_rows(data_row_items) task.wait_till_done() assert task.status == "COMPLETE" - dr_1_uid, dr_2_uid, dr_3_uid = [t['id'] for t in task.result] - - gk_1 = ' global key' - gk_2 = 'global key' - gk_3 = 'global key ' - - assignment_inputs = [{ - "data_row_id": dr_1_uid, - "global_key": gk_1 - }, { - "data_row_id": dr_2_uid, - "global_key": gk_2 - }, { - "data_row_id": dr_3_uid, - "global_key": gk_3 - }] + dr_1_uid, dr_2_uid, dr_3_uid = [t["id"] for t in task.result] + + gk_1 = " global key" + gk_2 = "global key" + gk_3 = "global key " + + assignment_inputs = [ + {"data_row_id": dr_1_uid, "global_key": gk_1}, + {"data_row_id": dr_2_uid, "global_key": gk_2}, + {"data_row_id": dr_3_uid, "global_key": gk_3}, + ] res = client.assign_global_keys_to_data_rows(assignment_inputs) - assert len(res['results']) == 0 - assert len(res['errors']) == 3 - assert res['status'] == 'FAILURE' - assign_errors_ids = set([e['data_row_id'] for e in res['errors']]) - assign_errors_gks = set([e['global_key'] for e in res['errors']]) - assign_errors_msgs = set([e['error'] for e in res['errors']]) + assert len(res["results"]) == 0 + assert len(res["errors"]) == 3 + assert res["status"] == "FAILURE" + assign_errors_ids = set([e["data_row_id"] for e in res["errors"]]) + assign_errors_gks = set([e["global_key"] for e in res["errors"]]) + assign_errors_msgs = set([e["error"] for e in res["errors"]]) assert assign_errors_ids == set([dr_1_uid, dr_2_uid, dr_3_uid]) assert assign_errors_gks == set([gk_1, gk_2, gk_3]) - assert assign_errors_msgs == set([ - 'Invalid assignment. Either DataRow does not exist, or globalKey is invalid', - 'Invalid assignment. Either DataRow does not exist, or globalKey is invalid', - 'Invalid assignment. Either DataRow does not exist, or globalKey is invalid' - ]) + assert assign_errors_msgs == set( + [ + "Invalid assignment. Either DataRow does not exist, or globalKey is invalid", + "Invalid assignment. Either DataRow does not exist, or globalKey is invalid", + "Invalid assignment. 
Either DataRow does not exist, or globalKey is invalid", + ] + ) def test_get_data_row_ids_for_global_keys(client, dataset, image_url): gk_1 = str(uuid.uuid4()) gk_2 = str(uuid.uuid4()) - dr_1 = dataset.create_data_row(row_data=image_url, - external_id="hello", - global_key=gk_1) - dr_2 = dataset.create_data_row(row_data=image_url, - external_id="world", - global_key=gk_2) + dr_1 = dataset.create_data_row( + row_data=image_url, external_id="hello", global_key=gk_1 + ) + dr_2 = dataset.create_data_row( + row_data=image_url, external_id="world", global_key=gk_2 + ) res = client.get_data_row_ids_for_global_keys([gk_1]) - assert res['status'] == "SUCCESS" - assert res['errors'] == [] - assert res['results'] == [dr_1.uid] + assert res["status"] == "SUCCESS" + assert res["errors"] == [] + assert res["results"] == [dr_1.uid] res = client.get_data_row_ids_for_global_keys([gk_2]) - assert res['status'] == "SUCCESS" - assert res['errors'] == [] - assert res['results'] == [dr_2.uid] + assert res["status"] == "SUCCESS" + assert res["errors"] == [] + assert res["results"] == [dr_2.uid] res = client.get_data_row_ids_for_global_keys([gk_1, gk_2]) - assert res['status'] == "SUCCESS" - assert res['errors'] == [] - assert res['results'] == [dr_1.uid, dr_2.uid] + assert res["status"] == "SUCCESS" + assert res["errors"] == [] + assert res["results"] == [dr_1.uid, dr_2.uid] def test_get_data_row_ids_for_invalid_global_keys(client, dataset, image_url): @@ -183,24 +173,24 @@ def test_get_data_row_ids_for_invalid_global_keys(client, dataset, image_url): gk_2 = str(uuid.uuid4()) dr_1 = dataset.create_data_row(row_data=image_url, external_id="hello") - dr_2 = dataset.create_data_row(row_data=image_url, - external_id="world", - global_key=gk_2) + dr_2 = dataset.create_data_row( + row_data=image_url, external_id="world", global_key=gk_2 + ) res = client.get_data_row_ids_for_global_keys([gk_1]) - assert res['status'] == "FAILURE" - assert len(res['errors']) == 1 - assert res['errors'][0]['error'] == "Data Row not found" - assert res['errors'][0]['global_key'] == gk_1 + assert res["status"] == "FAILURE" + assert len(res["errors"]) == 1 + assert res["errors"][0]["error"] == "Data Row not found" + assert res["errors"][0]["global_key"] == gk_1 res = client.get_data_row_ids_for_global_keys([gk_1, gk_2]) - assert res['status'] == "PARTIAL SUCCESS" + assert res["status"] == "PARTIAL SUCCESS" - assert len(res['errors']) == 1 - assert len(res['results']) == 2 + assert len(res["errors"]) == 1 + assert len(res["results"]) == 2 - assert res['errors'][0]['error'] == "Data Row not found" - assert res['errors'][0]['global_key'] == gk_1 + assert res["errors"][0]["error"] == "Data Row not found" + assert res["errors"][0]["global_key"] == gk_1 - assert res['results'][0] == '' - assert res['results'][1] == dr_2.uid + assert res["results"][0] == "" + assert res["results"][1] == dr_2.uid diff --git a/libs/labelbox/tests/integration/test_label.py b/libs/labelbox/tests/integration/test_label.py index c7221553e..1bd8a8276 100644 --- a/libs/labelbox/tests/integration/test_label.py +++ b/libs/labelbox/tests/integration/test_label.py @@ -29,10 +29,10 @@ def test_labels(configured_project_with_label): # TODO: Skipping this test in staging due to label not updating @pytest.mark.skipif( - condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem" or - os.environ["LABELBOX_TEST_ENVIRON"] == "staging" or - os.environ["LABELBOX_TEST_ENVIRON"] == "local" or - os.environ["LABELBOX_TEST_ENVIRON"] == "custom", + 
condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem" + or os.environ["LABELBOX_TEST_ENVIRON"] == "staging" + or os.environ["LABELBOX_TEST_ENVIRON"] == "local" + or os.environ["LABELBOX_TEST_ENVIRON"] == "custom", reason="does not work for onprem", ) def test_label_update(configured_project_with_label): @@ -82,8 +82,10 @@ def test_upsert_label_scores(configured_project_with_label, client: Client): label = next(project.labels()) - scores = client.upsert_label_feedback(label_id=label.uid, - feedback="That's a great label!", - scores={"overall": 5}) + scores = client.upsert_label_feedback( + label_id=label.uid, + feedback="That's a great label!", + scores={"overall": 5}, + ) assert len(scores) == 1 assert scores[0].score == 5 diff --git a/libs/labelbox/tests/integration/test_labeling_dashboard.py b/libs/labelbox/tests/integration/test_labeling_dashboard.py index 96d6af57f..97536e337 100644 --- a/libs/labelbox/tests/integration/test_labeling_dashboard.py +++ b/libs/labelbox/tests/integration/test_labeling_dashboard.py @@ -1,5 +1,21 @@ from datetime import datetime, timedelta -from labelbox.schema.search_filters import IntegerValue, RangeDateTimeOperatorWithSingleValue, RangeOperatorWithSingleValue, DateRange, RangeOperatorWithValue, DateRangeValue, DateValue, IdOperator, OperationType, OrganizationFilter, TaskCompletedCountFilter, WorkforceRequestedDateFilter, WorkforceRequestedDateRangeFilter, WorkspaceFilter, TaskRemainingCountFilter +from labelbox.schema.search_filters import ( + IntegerValue, + RangeDateTimeOperatorWithSingleValue, + RangeOperatorWithSingleValue, + DateRange, + RangeOperatorWithValue, + DateRangeValue, + DateValue, + IdOperator, + OperationType, + OrganizationFilter, + TaskCompletedCountFilter, + WorkforceRequestedDateFilter, + WorkforceRequestedDateRangeFilter, + WorkspaceFilter, + TaskRemainingCountFilter, +) def test_request_labeling_service_dashboard(requested_labeling_service): @@ -20,12 +36,14 @@ def test_request_labeling_service_dashboard_filters(requested_labeling_service): project, _ = requested_labeling_service organization = project.client.get_organization() - org_filter = OrganizationFilter(operator=IdOperator.Is, - values=[organization.uid]) + org_filter = OrganizationFilter( + operator=IdOperator.Is, values=[organization.uid] + ) try: project.client.get_labeling_service_dashboards( - search_query=[org_filter]).get_one() + search_query=[org_filter] + ).get_one() except Exception as e: assert False, f"An exception was raised: {e}" @@ -33,41 +51,55 @@ def test_request_labeling_service_dashboard_filters(requested_labeling_service): operation=OperationType.WorforceRequestedDate, value=DateValue( operator=RangeDateTimeOperatorWithSingleValue.GreaterThanOrEqual, - value=datetime.strptime("2024-01-01", "%Y-%m-%d"))) - year_from_now = (datetime.now() + timedelta(days=365)) + value=datetime.strptime("2024-01-01", "%Y-%m-%d"), + ), + ) + year_from_now = datetime.now() + timedelta(days=365) workforce_requested_filter_before = WorkforceRequestedDateFilter( operation=OperationType.WorforceRequestedDate, value=DateValue( operator=RangeDateTimeOperatorWithSingleValue.LessThanOrEqual, - value=year_from_now)) + value=year_from_now, + ), + ) try: - project.client.get_labeling_service_dashboards(search_query=[ - workforce_requested_filter_after, workforce_requested_filter_before - ]).get_one() + project.client.get_labeling_service_dashboards( + search_query=[ + workforce_requested_filter_after, + workforce_requested_filter_before, + ] + ).get_one() except Exception as e: 
assert False, f"An exception was raised: {e}" workforce_date_range_filter = WorkforceRequestedDateRangeFilter( operation=OperationType.WorforceRequestedDate, - value=DateRangeValue(operator=RangeOperatorWithValue.Between, - value=DateRange(min="2024-01-01T00:00:00-0800", - max=year_from_now))) + value=DateRangeValue( + operator=RangeOperatorWithValue.Between, + value=DateRange(min="2024-01-01T00:00:00-0800", max=year_from_now), + ), + ) try: project.client.get_labeling_service_dashboards( - search_query=[workforce_date_range_filter]).get_one() + search_query=[workforce_date_range_filter] + ).get_one() except Exception as e: assert False, f"An exception was raised: {e}" # with non existing data workspace_id = "clzzu4rme000008l42vnl4kre" - workspace_filter = WorkspaceFilter(operation=OperationType.Workspace, - operator=IdOperator.Is, - values=[workspace_id]) + workspace_filter = WorkspaceFilter( + operation=OperationType.Workspace, + operator=IdOperator.Is, + values=[workspace_id], + ) labeling_service_dashboard = [ - ld for ld in project.client.get_labeling_service_dashboards( - search_query=[workspace_filter]) + ld + for ld in project.client.get_labeling_service_dashboards( + search_query=[workspace_filter] + ) ] assert len(labeling_service_dashboard) == 0 assert labeling_service_dashboard == [] @@ -75,15 +107,19 @@ def test_request_labeling_service_dashboard_filters(requested_labeling_service): task_done_count_filter = TaskCompletedCountFilter( operation=OperationType.TaskCompletedCount, value=IntegerValue( - operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, value=0)) + operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, value=0 + ), + ) task_remaining_count_filter = TaskRemainingCountFilter( operation=OperationType.TaskRemainingCount, value=IntegerValue( - operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, value=0)) + operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, value=0 + ), + ) try: project.client.get_labeling_service_dashboards( - search_query=[task_done_count_filter, task_remaining_count_filter - ]).get_one() + search_query=[task_done_count_filter, task_remaining_count_filter] + ).get_one() except Exception as e: assert False, f"An exception was raised: {e}" diff --git a/libs/labelbox/tests/integration/test_labeling_frontend.py b/libs/labelbox/tests/integration/test_labeling_frontend.py index d13871372..d6ea1aac9 100644 --- a/libs/labelbox/tests/integration/test_labeling_frontend.py +++ b/libs/labelbox/tests/integration/test_labeling_frontend.py @@ -6,14 +6,16 @@ def test_get_labeling_frontends(client): filtered_frontends = list( - client.get_labeling_frontends(where=LabelingFrontend.name == 'Editor')) + client.get_labeling_frontends(where=LabelingFrontend.name == "Editor") + ) assert len(filtered_frontends) def test_labeling_frontend_connecting_to_project(project): client = project.client default_labeling_frontend = next( - client.get_labeling_frontends(where=LabelingFrontend.name == "Editor")) + client.get_labeling_frontends(where=LabelingFrontend.name == "Editor") + ) assert project.labeling_frontend() is None diff --git a/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py b/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py index 51c56353c..bd14040de 100644 --- a/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py +++ b/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py @@ -8,13 +8,20 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): [project, _, 
data_rows] = consensus_project_with_batch init_labeling_parameter_overrides = list( - project.labeling_parameter_overrides()) + project.labeling_parameter_overrides() + ) assert len(init_labeling_parameter_overrides) == 3 - assert {o.number_of_labels for o in init_labeling_parameter_overrides - } == {1, 1, 1} + assert {o.number_of_labels for o in init_labeling_parameter_overrides} == { + 1, + 1, + 1, + } assert {o.priority for o in init_labeling_parameter_overrides} == {5, 5, 5} - assert {o.data_row().uid for o in init_labeling_parameter_overrides - } == {data_rows[0].uid, data_rows[1].uid, data_rows[2].uid} + assert {o.data_row().uid for o in init_labeling_parameter_overrides} == { + data_rows[0].uid, + data_rows[1].uid, + data_rows[2].uid, + } data = [(data_rows[0], 4, 2), (data_rows[1], 3)] success = project.set_labeling_parameter_overrides(data) @@ -28,8 +35,11 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): for override in updated_overrides: assert isinstance(override.data_row(), DataRow) - data = [(UniqueId(data_rows[0].uid), 1, 2), (UniqueId(data_rows[1].uid), 2), - (UniqueId(data_rows[2].uid), 3)] + data = [ + (UniqueId(data_rows[0].uid), 1, 2), + (UniqueId(data_rows[1].uid), 2), + (UniqueId(data_rows[2].uid), 3), + ] success = project.set_labeling_parameter_overrides(data) assert success updated_overrides = list(project.labeling_parameter_overrides()) @@ -37,9 +47,11 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1} assert {o.priority for o in updated_overrides} == {1, 2, 3} - data = [(GlobalKey(data_rows[0].global_key), 2, 2), - (GlobalKey(data_rows[1].global_key), 3, 3), - (GlobalKey(data_rows[2].global_key), 4)] + data = [ + (GlobalKey(data_rows[0].global_key), 2, 2), + (GlobalKey(data_rows[1].global_key), 3, 3), + (GlobalKey(data_rows[2].global_key), 4), + ] success = project.set_labeling_parameter_overrides(data) assert success updated_overrides = list(project.labeling_parameter_overrides()) @@ -50,21 +62,26 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): with pytest.raises(TypeError) as exc_info: data = [(data_rows[2], "a_string", 3)] project.set_labeling_parameter_overrides(data) - assert str(exc_info.value) == \ - f"Priority must be an int. Found for data_row_identifier {data_rows[2].uid}" + assert ( + str(exc_info.value) + == f"Priority must be an int. Found for data_row_identifier {data_rows[2].uid}" + ) with pytest.raises(TypeError) as exc_info: data = [(data_rows[2].uid, 1)] project.set_labeling_parameter_overrides(data) - assert str(exc_info.value) == \ - f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found for data_row_identifier {data_rows[2].uid}" + assert ( + str(exc_info.value) + == f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. 
Found for data_row_identifier {data_rows[2].uid}" + ) def test_set_labeling_priority(consensus_project_with_batch): [project, _, data_rows] = consensus_project_with_batch init_labeling_parameter_overrides = list( - project.labeling_parameter_overrides()) + project.labeling_parameter_overrides() + ) assert len(init_labeling_parameter_overrides) == 3 assert {o.priority for o in init_labeling_parameter_overrides} == {5, 5, 5} diff --git a/libs/labelbox/tests/integration/test_labeling_service.py b/libs/labelbox/tests/integration/test_labeling_service.py index be0b8a6ee..09b5c24a1 100644 --- a/libs/labelbox/tests/integration/test_labeling_service.py +++ b/libs/labelbox/tests/integration/test_labeling_service.py @@ -15,8 +15,12 @@ def test_start_labeling_service(project): def test_request_labeling_service_moe_offline_project( - rand_gen, offline_chat_evaluation_project, chat_evaluation_ontology, - offline_conversational_data_row, model_config): + rand_gen, + offline_chat_evaluation_project, + chat_evaluation_ontology, + offline_conversational_data_row, + model_config, +): project = offline_chat_evaluation_project project.connect_ontology(chat_evaluation_ontology) @@ -25,43 +29,48 @@ def test_request_labeling_service_moe_offline_project( [offline_conversational_data_row.uid], # sample of data row objects ) - project.upsert_instructions('tests/integration/media/sample_pdf.pdf') + project.upsert_instructions("tests/integration/media/sample_pdf.pdf") labeling_service = project.get_labeling_service() labeling_service.request() - assert project.get_labeling_service_status( - ) == LabelingServiceStatus.Requested + assert ( + project.get_labeling_service_status() == LabelingServiceStatus.Requested + ) def test_request_labeling_service_moe_project( - rand_gen, live_chat_evaluation_project_with_new_dataset, - chat_evaluation_ontology, model_config): + rand_gen, + live_chat_evaluation_project_with_new_dataset, + chat_evaluation_ontology, + model_config, +): project = live_chat_evaluation_project_with_new_dataset project.connect_ontology(chat_evaluation_ontology) - project.upsert_instructions('tests/integration/media/sample_pdf.pdf') + project.upsert_instructions("tests/integration/media/sample_pdf.pdf") labeling_service = project.get_labeling_service() with pytest.raises( - LabelboxError, - match= - '[{"errorType":"PROJECT_MODEL_CONFIG","errorMessage":"Project model config is not completed"}]' + LabelboxError, + match='[{"errorType":"PROJECT_MODEL_CONFIG","errorMessage":"Project model config is not completed"}]', ): labeling_service.request() project.add_model_config(model_config.uid) project.set_project_model_setup_complete() labeling_service.request() - assert project.get_labeling_service_status( - ) == LabelingServiceStatus.Requested + assert ( + project.get_labeling_service_status() == LabelingServiceStatus.Requested + ) def test_request_labeling_service_incomplete_requirements(ontology, project): - labeling_service = project.get_labeling_service( + labeling_service = ( + project.get_labeling_service() ) # project fixture is an Image type project - with pytest.raises(ResourceNotFoundError, - match="Associated ontology id could not be found" - ): # No labeling service by default + with pytest.raises( + ResourceNotFoundError, match="Associated ontology id could not be found" + ): # No labeling service by default labeling_service.request() project.connect_ontology(ontology) with pytest.raises(LabelboxError): diff --git a/libs/labelbox/tests/integration/test_legacy_project.py 
b/libs/labelbox/tests/integration/test_legacy_project.py index fbdf8b252..320a2191d 100644 --- a/libs/labelbox/tests/integration/test_legacy_project.py +++ b/libs/labelbox/tests/integration/test_legacy_project.py @@ -5,9 +5,8 @@ def test_project_dataset(client, rand_gen): with pytest.raises( - ValueError, - match= - "Dataset queue mode is deprecated. Please prefer Batch queue mode." + ValueError, + match="Dataset queue mode is deprecated. Please prefer Batch queue mode.", ): client.create_project( name=rand_gen(str), @@ -30,10 +29,12 @@ def test_project_auto_audit_parameters(client, rand_gen): def test_project_name_parameter(client, rand_gen): - with pytest.raises(ValueError, - match="project name must be a valid string."): + with pytest.raises( + ValueError, match="project name must be a valid string." + ): client.create_project() - with pytest.raises(ValueError, - match="project name must be a valid string."): + with pytest.raises( + ValueError, match="project name must be a valid string." + ): client.create_project(name=" ") diff --git a/libs/labelbox/tests/integration/test_model_config.py b/libs/labelbox/tests/integration/test_model_config.py index 960b096c6..7a060b917 100644 --- a/libs/labelbox/tests/integration/test_model_config.py +++ b/libs/labelbox/tests/integration/test_model_config.py @@ -1,16 +1,22 @@ import pytest from labelbox.exceptions import ResourceNotFoundError + def test_create_model_config(client, valid_model_id): - model_config = client.create_model_config("model_config", valid_model_id, {"param": "value"}) + model_config = client.create_model_config( + "model_config", valid_model_id, {"param": "value"} + ) assert model_config.inference_params["param"] == "value" assert model_config.name == "model_config" assert model_config.model_id == valid_model_id def test_delete_model_config(client, valid_model_id): - model_config_id = client.create_model_config("model_config", valid_model_id, {"param": "value"}) - assert(client.delete_model_config(model_config_id.uid)) + model_config_id = client.create_model_config( + "model_config", valid_model_id, {"param": "value"} + ) + assert client.delete_model_config(model_config_id.uid) + def test_delete_nonexistant_model_config(client): with pytest.raises(ResourceNotFoundError): diff --git a/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py b/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py index 2ff5607c3..bb1756afb 100644 --- a/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py +++ b/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py @@ -1,11 +1,14 @@ import pytest -def test_create_offline_chat_evaluation_project(client, rand_gen, - offline_chat_evaluation_project, - chat_evaluation_ontology, - offline_conversational_data_row, - model_config): +def test_create_offline_chat_evaluation_project( + client, + rand_gen, + offline_chat_evaluation_project, + chat_evaluation_ontology, + offline_conversational_data_row, + model_config, +): project = offline_chat_evaluation_project assert project diff --git a/libs/labelbox/tests/integration/test_ontology.py b/libs/labelbox/tests/integration/test_ontology.py index 0b6b23e73..91ef74a39 100644 --- a/libs/labelbox/tests/integration/test_ontology.py +++ b/libs/labelbox/tests/integration/test_ontology.py @@ -9,57 +9,67 @@ def test_feature_schema_is_not_archived(client, ontology): - feature_schema_to_check = ontology.normalized['tools'][0] + feature_schema_to_check = ontology.normalized["tools"][0] result = 
client.is_feature_schema_archived( - ontology.uid, feature_schema_to_check['featureSchemaId']) + ontology.uid, feature_schema_to_check["featureSchemaId"] + ) assert result == False def test_feature_schema_is_archived(client, configured_project_with_label): project, _, _, label = configured_project_with_label ontology = project.ontology() - feature_schema_id = ontology.normalized['tools'][0]['featureSchemaId'] - result = client.delete_feature_schema_from_ontology(ontology.uid, - feature_schema_id) + feature_schema_id = ontology.normalized["tools"][0]["featureSchemaId"] + result = client.delete_feature_schema_from_ontology( + ontology.uid, feature_schema_id + ) assert result.archived == True and result.deleted == False - assert client.is_feature_schema_archived(ontology.uid, - feature_schema_id) == True + assert ( + client.is_feature_schema_archived(ontology.uid, feature_schema_id) + == True + ) def test_is_feature_schema_archived_for_non_existing_feature_schema( - client, ontology): + client, ontology +): with pytest.raises( - Exception, - match="The specified feature schema was not in the ontology"): - client.is_feature_schema_archived(ontology.uid, - 'invalid-feature-schema-id') + Exception, match="The specified feature schema was not in the ontology" + ): + client.is_feature_schema_archived( + ontology.uid, "invalid-feature-schema-id" + ) def test_is_feature_schema_archived_for_non_existing_ontology(client, ontology): - feature_schema_to_unarchive = ontology.normalized['tools'][0] + feature_schema_to_unarchive = ontology.normalized["tools"][0] with pytest.raises( - Exception, - match="Resource 'Ontology' not found for params: 'invalid-ontology'" + Exception, + match="Resource 'Ontology' not found for params: 'invalid-ontology'", ): client.is_feature_schema_archived( - 'invalid-ontology', feature_schema_to_unarchive['featureSchemaId']) + "invalid-ontology", feature_schema_to_unarchive["featureSchemaId"] + ) def test_delete_tool_feature_from_ontology(client, ontology): - feature_schema_to_delete = ontology.normalized['tools'][0] - assert len(ontology.normalized['tools']) == 2 + feature_schema_to_delete = ontology.normalized["tools"][0] + assert len(ontology.normalized["tools"]) == 2 result = client.delete_feature_schema_from_ontology( - ontology.uid, feature_schema_to_delete['featureSchemaId']) + ontology.uid, feature_schema_to_delete["featureSchemaId"] + ) assert result.deleted == True assert result.archived == False updatedOntology = client.get_ontology(ontology.uid) - assert len(updatedOntology.normalized['tools']) == 1 + assert len(updatedOntology.normalized["tools"]) == 1 -@pytest.mark.skip(reason="normalized ontology contains Relationship, " - "which is not finalized yet. introduce this back when" - "Relationship feature is complete and we introduce" - "a Relationship object to the ontology that we can parse") +@pytest.mark.skip( + reason="normalized ontology contains Relationship, " + "which is not finalized yet. 
introduce this back when" + "Relationship feature is complete and we introduce" + "a Relationship object to the ontology that we can parse" +) def test_from_project_ontology(project) -> None: o = OntologyBuilder.from_project(project) assert o.asdict() == project.ontology().normalized @@ -74,11 +84,12 @@ def test_from_project_ontology(project) -> None: def test_deletes_an_ontology(client): tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] ontology = client.create_ontology_from_feature_schemas( - name='ontology name', + name="ontology name", feature_schema_ids=[feature_schema_id], - media_type=MediaType.Image) + media_type=MediaType.Image, + ) assert client.delete_unused_ontology(ontology.uid) is None @@ -86,22 +97,25 @@ def test_deletes_an_ontology(client): def test_cant_delete_an_ontology_with_project(client): - project = client.create_project(name="test project", - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) + project = client.create_project( + name="test project", + queue_mode=QueueMode.Batch, + media_type=MediaType.Image, + ) tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] ontology = client.create_ontology_from_feature_schemas( - name='ontology name', + name="ontology name", feature_schema_ids=[feature_schema_id], - media_type=MediaType.Image) + media_type=MediaType.Image, + ) project.connect_ontology(ontology) with pytest.raises( - Exception, - match= - "Failed to delete the ontology, message: Cannot delete an ontology connected to a project. The ontology is connected to projects: " - + project.uid): + Exception, + match="Failed to delete the ontology, message: Cannot delete an ontology connected to a project. 
The ontology is connected to projects: " + + project.uid, + ): client.delete_unused_ontology(ontology.uid) project.delete() @@ -110,56 +124,72 @@ def test_cant_delete_an_ontology_with_project(client): def test_inserts_a_feature_schema_at_given_position(client): - tool1 = {'tool': 'polygon', 'name': 'tool1', 'color': 'blue'} - tool2 = {'tool': 'polygon', 'name': 'tool2', 'color': 'blue'} + tool1 = {"tool": "polygon", "name": "tool1", "color": "blue"} + tool2 = {"tool": "polygon", "name": "tool2", "color": "blue"} ontology_normalized_json = {"tools": [tool1, tool2], "classifications": []} - ontology = client.create_ontology(name="ontology", - normalized=ontology_normalized_json, - media_type=MediaType.Image) + ontology = client.create_ontology( + name="ontology", + normalized=ontology_normalized_json, + media_type=MediaType.Image, + ) created_feature_schema = client.upsert_feature_schema(point.asdict()) client.insert_feature_schema_into_ontology( - created_feature_schema.normalized['featureSchemaId'], ontology.uid, 1) + created_feature_schema.normalized["featureSchemaId"], ontology.uid, 1 + ) ontology = client.get_ontology(ontology.uid) - assert ontology.normalized['tools'][1][ - 'schemaNodeId'] == created_feature_schema.normalized['schemaNodeId'] + assert ( + ontology.normalized["tools"][1]["schemaNodeId"] + == created_feature_schema.normalized["schemaNodeId"] + ) client.delete_unused_ontology(ontology.uid) def test_moves_already_added_feature_schema_in_ontology(client): - tool1 = {'tool': 'polygon', 'name': 'tool1', 'color': 'blue'} + tool1 = {"tool": "polygon", "name": "tool1", "color": "blue"} ontology_normalized_json = {"tools": [tool1], "classifications": []} - ontology = client.create_ontology(name="ontology", - normalized=ontology_normalized_json, - media_type=MediaType.Image) + ontology = client.create_ontology( + name="ontology", + normalized=ontology_normalized_json, + media_type=MediaType.Image, + ) created_feature_schema = client.upsert_feature_schema(point.asdict()) - feature_schema_id = created_feature_schema.normalized['featureSchemaId'] - client.insert_feature_schema_into_ontology(feature_schema_id, ontology.uid, - 1) + feature_schema_id = created_feature_schema.normalized["featureSchemaId"] + client.insert_feature_schema_into_ontology( + feature_schema_id, ontology.uid, 1 + ) ontology = client.get_ontology(ontology.uid) - assert ontology.normalized['tools'][1][ - 'schemaNodeId'] == created_feature_schema.normalized['schemaNodeId'] - client.insert_feature_schema_into_ontology(feature_schema_id, ontology.uid, - 0) + assert ( + ontology.normalized["tools"][1]["schemaNodeId"] + == created_feature_schema.normalized["schemaNodeId"] + ) + client.insert_feature_schema_into_ontology( + feature_schema_id, ontology.uid, 0 + ) ontology = client.get_ontology(ontology.uid) - assert ontology.normalized['tools'][0][ - 'schemaNodeId'] == created_feature_schema.normalized['schemaNodeId'] + assert ( + ontology.normalized["tools"][0]["schemaNodeId"] + == created_feature_schema.normalized["schemaNodeId"] + ) client.delete_unused_ontology(ontology.uid) def test_does_not_include_used_ontologies(client): tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] ontology_with_project = client.create_ontology_from_feature_schemas( - name='ontology name', + name="ontology name", feature_schema_ids=[feature_schema_id], - media_type=MediaType.Image) - project = client.create_project(name="test 
project", - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) + media_type=MediaType.Image, + ) + project = client.create_project( + name="test project", + queue_mode=QueueMode.Batch, + media_type=MediaType.Image, + ) project.connect_ontology(ontology_with_project) unused_ontologies = client.get_unused_ontologies() @@ -185,10 +215,10 @@ def name_for_read(rand_gen): @pytest.fixture def feature_schema_cat_normalized(name_for_read): yield { - 'tool': 'polygon', - 'name': name_for_read, - 'color': 'black', - 'classifications': [], + "tool": "polygon", + "name": name_for_read, + "color": "black", + "classifications": [], } @@ -199,26 +229,29 @@ def feature_schema_for_read(client, feature_schema_cat_normalized): client.delete_unused_feature_schema(feature_schema.uid) -def test_feature_schema_create_read(client, feature_schema_for_read, - name_for_read): +def test_feature_schema_create_read( + client, feature_schema_for_read, name_for_read +): created_feature_schema = feature_schema_for_read queried_feature_schema = client.get_feature_schema( - created_feature_schema.uid) + created_feature_schema.uid + ) for attr in Entity.FeatureSchema.fields(): - assert _get_attr_stringify_json(created_feature_schema, - attr) == _get_attr_stringify_json( - queried_feature_schema, attr) + assert _get_attr_stringify_json( + created_feature_schema, attr + ) == _get_attr_stringify_json(queried_feature_schema, attr) time.sleep(3) # Slight delay for searching queried_feature_schemas = list(client.get_feature_schemas(name_for_read)) - assert [feature_schema.name for feature_schema in queried_feature_schemas - ] == [name_for_read] + assert [ + feature_schema.name for feature_schema in queried_feature_schemas + ] == [name_for_read] queried_feature_schema = queried_feature_schemas[0] for attr in Entity.FeatureSchema.fields(): - assert _get_attr_stringify_json(created_feature_schema, - attr) == _get_attr_stringify_json( - queried_feature_schema, attr) + assert _get_attr_stringify_json( + created_feature_schema, attr + ) == _get_attr_stringify_json(queried_feature_schema, attr) def test_ontology_create_read( @@ -228,61 +261,67 @@ def test_ontology_create_read( ontology_name = f"test-ontology-{rand_gen(str)}" tool_name = f"test-ontology-tool-{rand_gen(str)}" feature_schema_cat_normalized = { - 'tool': 'polygon', - 'name': tool_name, - 'color': 'black', - 'classifications': [], + "tool": "polygon", + "name": tool_name, + "color": "black", + "classifications": [], } feature_schema = client.create_feature_schema(feature_schema_cat_normalized) created_ontology = client.create_ontology_from_feature_schemas( name=ontology_name, feature_schema_ids=[feature_schema.uid], - media_type=MediaType.Image) - tool_normalized = created_ontology.normalized['tools'][0] + media_type=MediaType.Image, + ) + tool_normalized = created_ontology.normalized["tools"][0] for k, v in feature_schema_cat_normalized.items(): assert tool_normalized[k] == v - assert tool_normalized['schemaNodeId'] is not None - assert tool_normalized['featureSchemaId'] == feature_schema.uid + assert tool_normalized["schemaNodeId"] is not None + assert tool_normalized["featureSchemaId"] == feature_schema.uid queried_ontology = client.get_ontology(created_ontology.uid) for attr in Entity.Ontology.fields(): - assert _get_attr_stringify_json(created_ontology, - attr) == _get_attr_stringify_json( - queried_ontology, attr) + assert _get_attr_stringify_json( + created_ontology, attr + ) == _get_attr_stringify_json(queried_ontology, attr) time.sleep(3) # Slight delay for 
searching queried_ontologies = list(client.get_ontologies(ontology_name)) assert [ontology.name for ontology in queried_ontologies] == [ontology_name] queried_ontology = queried_ontologies[0] for attr in Entity.Ontology.fields(): - assert _get_attr_stringify_json(created_ontology, - attr) == _get_attr_stringify_json( - queried_ontology, attr) + assert _get_attr_stringify_json( + created_ontology, attr + ) == _get_attr_stringify_json(queried_ontology, attr) def test_unarchive_feature_schema_node(client, ontology): - feature_schema_to_unarchive = ontology.normalized['tools'][0] + feature_schema_to_unarchive = ontology.normalized["tools"][0] result = client.unarchive_feature_schema_node( - ontology.uid, feature_schema_to_unarchive['featureSchemaId']) + ontology.uid, feature_schema_to_unarchive["featureSchemaId"] + ) assert result == None def test_unarchive_feature_schema_node_for_non_existing_feature_schema( - client, ontology): + client, ontology +): with pytest.raises( - Exception, - match= - "Failed to find feature schema node by id: invalid-feature-schema-id" + Exception, + match="Failed to find feature schema node by id: invalid-feature-schema-id", ): - client.unarchive_feature_schema_node(ontology.uid, - 'invalid-feature-schema-id') + client.unarchive_feature_schema_node( + ontology.uid, "invalid-feature-schema-id" + ) def test_unarchive_feature_schema_node_for_non_existing_ontology( - client, ontology): - feature_schema_to_unarchive = ontology.normalized['tools'][0] - with pytest.raises(Exception, - match="Failed to find ontology by id: invalid-ontology"): + client, ontology +): + feature_schema_to_unarchive = ontology.normalized["tools"][0] + with pytest.raises( + Exception, match="Failed to find ontology by id: invalid-ontology" + ): client.unarchive_feature_schema_node( - 'invalid-ontology', feature_schema_to_unarchive['featureSchemaId']) + "invalid-ontology", feature_schema_to_unarchive["featureSchemaId"] + ) diff --git a/libs/labelbox/tests/integration/test_project.py b/libs/labelbox/tests/integration/test_project.py index 7b63ee391..a38fa2b5d 100644 --- a/libs/labelbox/tests/integration/test_project.py +++ b/libs/labelbox/tests/integration/test_project.py @@ -71,7 +71,9 @@ def delete_tag(tag_id: str): id } } - """, {"tag_id": tag_id}) + """, + {"tag_id": tag_id}, + ) return res org = client.get_organization() @@ -89,7 +91,7 @@ def delete_tag(tag_id: str): tagA = client.get_organization().create_resource_tag(tag) assert tagA.text == textA - assert '#' + tagA.color == colorA + assert "#" + tagA.color == colorA assert tagA.uid is not None tags = org.get_resource_tags() @@ -98,7 +100,7 @@ def delete_tag(tag_id: str): tagB = client.get_organization().create_resource_tag(tagB) assert tagB.text == textB - assert '#' + tagB.color == colorB + assert "#" + tagB.color == colorB assert tagB.uid is not None tags = client.get_organization().get_resource_tags() @@ -107,7 +109,8 @@ def delete_tag(tag_id: str): assert lenB > lenA project_resource_tag = client.get_project( - p1.uid).update_project_resource_tags([str(tagA.uid)]) + p1.uid + ).update_project_resource_tags([str(tagA.uid)]) assert len(project_resource_tag) == 1 assert project_resource_tag[0].uid == tagA.uid @@ -136,75 +139,84 @@ def test_extend_reservations(project): project.extend_reservations("InvalidQueueType") -@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", - reason="new mutation does not work for onprem") +@pytest.mark.skipif( + condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem", + 
reason="new mutation does not work for onprem", +) def test_attach_instructions(client, project): with pytest.raises(ValueError) as execinfo: - project.upsert_instructions('tests/integration/media/sample_pdf.pdf') - assert str( - execinfo.value - ) == "Cannot attach instructions to a project that has not been set up." + project.upsert_instructions("tests/integration/media/sample_pdf.pdf") + assert ( + str(execinfo.value) + == "Cannot attach instructions to a project that has not been set up." + ) editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] + client.get_labeling_frontends(where=LabelingFrontend.name == "editor") + )[0] empty_ontology = {"tools": [], "classifications": []} project.setup(editor, empty_ontology) - project.upsert_instructions('tests/integration/media/sample_pdf.pdf') + project.upsert_instructions("tests/integration/media/sample_pdf.pdf") time.sleep(3) - assert project.ontology().normalized['projectInstructions'] is not None + assert project.ontology().normalized["projectInstructions"] is not None with pytest.raises(ValueError) as exc_info: - project.upsert_instructions('/tmp/file.invalid_file_extension') + project.upsert_instructions("/tmp/file.invalid_file_extension") assert "instructions_file must be a pdf or html file. Found" in str( - exc_info.value) + exc_info.value + ) -@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", - reason="new mutation does not work for onprem") +@pytest.mark.skipif( + condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem", + reason="new mutation does not work for onprem", +) def test_html_instructions(project_with_empty_ontology): - html_file_path = '/tmp/instructions.html' + html_file_path = "/tmp/instructions.html" sample_html_str = "" - with open(html_file_path, 'w') as file: + with open(html_file_path, "w") as file: file.write(sample_html_str) project_with_empty_ontology.upsert_instructions(html_file_path) updated_ontology = project_with_empty_ontology.ontology().normalized - instructions = updated_ontology.pop('projectInstructions') + instructions = updated_ontology.pop("projectInstructions") assert requests.get(instructions).text == sample_html_str -@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", - reason="new mutation does not work for onprem") +@pytest.mark.skipif( + condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem", + reason="new mutation does not work for onprem", +) def test_same_ontology_after_instructions( - configured_project_with_complex_ontology): + configured_project_with_complex_ontology, +): project, _ = configured_project_with_complex_ontology initial_ontology = project.ontology().normalized - project.upsert_instructions('tests/assets/loremipsum.pdf') + project.upsert_instructions("tests/assets/loremipsum.pdf") updated_ontology = project.ontology().normalized - instructions = updated_ontology.pop('projectInstructions') + instructions = updated_ontology.pop("projectInstructions") assert initial_ontology == updated_ontology assert instructions is not None def test_batches(project: Project, dataset: Dataset, image_url): - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image" - }, - ] * 2) + task = dataset.create_data_rows( + [ + {"row_data": image_url, "external_id": "my-image"}, + ] + * 2 + ) task.wait_till_done() export_task = dataset.export() export_task.wait_till_done() stream = export_task.get_buffered_stream() data_rows = [dr.json["data_row"]["id"] for dr in stream] - 
batch_one = f'batch one {uuid.uuid4()}' - batch_two = f'batch two {uuid.uuid4()}' + batch_one = f"batch one {uuid.uuid4()}" + batch_two = f"batch two {uuid.uuid4()}" project.create_batch(batch_one, [data_rows[0]]) project.create_batch(batch_two, [data_rows[1]]) @@ -212,19 +224,19 @@ def test_batches(project: Project, dataset: Dataset, image_url): assert names == {batch_one, batch_two} -@pytest.mark.parametrize('data_rows', [2], indirect=True) +@pytest.mark.parametrize("data_rows", [2], indirect=True) def test_create_batch_with_global_keys_sync(project: Project, data_rows): global_keys = [dr.global_key for dr in data_rows] - batch_name = f'batch {uuid.uuid4()}' + batch_name = f"batch {uuid.uuid4()}" batch = project.create_batch(batch_name, global_keys=global_keys) assert batch.size == len(set(data_rows)) -@pytest.mark.parametrize('data_rows', [2], indirect=True) +@pytest.mark.parametrize("data_rows", [2], indirect=True) def test_create_batch_with_global_keys_async(project: Project, data_rows): global_keys = [dr.global_key for dr in data_rows] - batch_name = f'batch {uuid.uuid4()}' + batch_name = f"batch {uuid.uuid4()}" batch = project._create_batch_async(batch_name, global_keys=global_keys) assert batch.size == len(set(data_rows)) @@ -243,28 +255,35 @@ def test_media_type(client, project: Project, rand_gen): for media_type in MediaType.get_supported_members(): # Exclude LLM media types for now, as they are not supported if MediaType[media_type] in [ - MediaType.LLMPromptCreation, - MediaType.LLMPromptResponseCreation, MediaType.LLM + MediaType.LLMPromptCreation, + MediaType.LLMPromptResponseCreation, + MediaType.LLM, ]: continue - project = client.create_project(name=rand_gen(str), - media_type=MediaType[media_type]) + project = client.create_project( + name=rand_gen(str), media_type=MediaType[media_type] + ) assert project.media_type == MediaType[media_type] project.delete() def test_queue_mode(client, rand_gen): - project = client.create_project(name=rand_gen(str)) # defaults to benchmark and consensus + project = client.create_project( + name=rand_gen(str) + ) # defaults to benchmark and consensus assert project.auto_audit_number_of_labels == 3 assert project.auto_audit_percentage == 0 - project = client.create_project(name=rand_gen(str), quality_modes=[QualityMode.Benchmark]) + project = client.create_project( + name=rand_gen(str), quality_modes=[QualityMode.Benchmark] + ) assert project.auto_audit_number_of_labels == 1 assert project.auto_audit_percentage == 1 project = client.create_project( - name=rand_gen(str), quality_modes=[QualityMode.Benchmark, QualityMode.Consensus] + name=rand_gen(str), + quality_modes=[QualityMode.Benchmark, QualityMode.Consensus], ) assert project.auto_audit_number_of_labels == 3 assert project.auto_audit_percentage == 0 @@ -282,14 +301,18 @@ def test_label_count(client, configured_batch_project_with_label): def test_clone(client, project, rand_gen): # cannot clone unknown project media type - project = client.create_project(name=rand_gen(str), - media_type=MediaType.Image) + project = client.create_project( + name=rand_gen(str), media_type=MediaType.Image + ) cloned_project = project.clone() assert cloned_project.description == project.description assert cloned_project.media_type == project.media_type assert cloned_project.queue_mode == project.queue_mode - assert cloned_project.auto_audit_number_of_labels == project.auto_audit_number_of_labels + assert ( + cloned_project.auto_audit_number_of_labels + == project.auto_audit_number_of_labels + ) assert 
cloned_project.auto_audit_percentage == project.auto_audit_percentage assert cloned_project.get_label_count() == 0 diff --git a/libs/labelbox/tests/integration/test_project_model_config.py b/libs/labelbox/tests/integration/test_project_model_config.py index 7b564b2af..2d783f62b 100644 --- a/libs/labelbox/tests/integration/test_project_model_config.py +++ b/libs/labelbox/tests/integration/test_project_model_config.py @@ -2,52 +2,67 @@ from labelbox.exceptions import ResourceNotFoundError -def test_add_single_model_config(live_chat_evaluation_project_with_new_dataset, - model_config): +def test_add_single_model_config( + live_chat_evaluation_project_with_new_dataset, model_config +): configured_project = live_chat_evaluation_project_with_new_dataset project_model_config_id = configured_project.add_model_config( - model_config.uid) + model_config.uid + ) - assert set(config.uid - for config in configured_project.project_model_configs()) == set( - [project_model_config_id]) + assert set( + config.uid for config in configured_project.project_model_configs() + ) == set([project_model_config_id]) assert configured_project.delete_project_model_config( - project_model_config_id) + project_model_config_id + ) -def test_add_multiple_model_config(client, rand_gen, - live_chat_evaluation_project_with_new_dataset, - model_config, valid_model_id): +def test_add_multiple_model_config( + client, + rand_gen, + live_chat_evaluation_project_with_new_dataset, + model_config, + valid_model_id, +): configured_project = live_chat_evaluation_project_with_new_dataset - second_model_config = client.create_model_config(rand_gen(str), - valid_model_id, - {"param": "value"}) + second_model_config = client.create_model_config( + rand_gen(str), valid_model_id, {"param": "value"} + ) first_project_model_config_id = configured_project.add_model_config( - model_config.uid) + model_config.uid + ) second_project_model_config_id = configured_project.add_model_config( - second_model_config.uid) + second_model_config.uid + ) expected_model_configs = set( - [first_project_model_config_id, second_project_model_config_id]) + [first_project_model_config_id, second_project_model_config_id] + ) - assert set( - config.uid for config in configured_project.project_model_configs() - ) == expected_model_configs + assert ( + set(config.uid for config in configured_project.project_model_configs()) + == expected_model_configs + ) for project_model_config_id in expected_model_configs: assert configured_project.delete_project_model_config( - project_model_config_id) + project_model_config_id + ) -def test_delete_project_model_config(live_chat_evaluation_project_with_new_dataset, - model_config): +def test_delete_project_model_config( + live_chat_evaluation_project_with_new_dataset, model_config +): configured_project = live_chat_evaluation_project_with_new_dataset assert configured_project.delete_project_model_config( - configured_project.add_model_config(model_config.uid)) + configured_project.add_model_config(model_config.uid) + ) assert not len(configured_project.project_model_configs()) def test_delete_nonexistant_project_model_config(configured_project): with pytest.raises(ResourceNotFoundError): configured_project.delete_project_model_config( - "nonexistant_project_model_config") + "nonexistant_project_model_config" + ) diff --git a/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py b/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py index d48514024..1c3e68c9a 100644 --- 
a/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py +++ b/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py @@ -4,24 +4,23 @@ def test_live_chat_evaluation_project( - live_chat_evaluation_project_with_new_dataset, model_config): - + live_chat_evaluation_project_with_new_dataset, model_config +): project = live_chat_evaluation_project_with_new_dataset project.set_project_model_setup_complete() assert bool(project.model_setup_complete) is True with pytest.raises( - expected_exception=LabelboxError, - match= - "Cannot create model config for project because model setup is complete" + expected_exception=LabelboxError, + match="Cannot create model config for project because model setup is complete", ): project.add_model_config(model_config.uid) def test_live_chat_evaluation_project_delete_cofig( - live_chat_evaluation_project_with_new_dataset, model_config): - + live_chat_evaluation_project_with_new_dataset, model_config +): project = live_chat_evaluation_project_with_new_dataset project_model_config_id = project.add_model_config(model_config.uid) assert project_model_config_id @@ -37,30 +36,27 @@ def test_live_chat_evaluation_project_delete_cofig( assert bool(project.model_setup_complete) is True with pytest.raises( - expected_exception=LabelboxError, - match= - "Cannot create model config for project because model setup is complete" + expected_exception=LabelboxError, + match="Cannot create model config for project because model setup is complete", ): project_model_config.delete() -def test_offline_chat_evaluation_project(offline_chat_evaluation_project, - model_config): - +def test_offline_chat_evaluation_project( + offline_chat_evaluation_project, model_config +): project = offline_chat_evaluation_project with pytest.raises( - expected_exception=OperationNotAllowedException, - match= - "Only live model chat evaluation projects can complete model setup" + expected_exception=OperationNotAllowedException, + match="Only live model chat evaluation projects can complete model setup", ): project.set_project_model_setup_complete() def test_any_other_project(project, model_config): with pytest.raises( - expected_exception=OperationNotAllowedException, - match= - "Only live model chat evaluation projects can complete model setup" + expected_exception=OperationNotAllowedException, + match="Only live model chat evaluation projects can complete model setup", ): project.set_project_model_setup_complete() diff --git a/libs/labelbox/tests/integration/test_project_setup.py b/libs/labelbox/tests/integration/test_project_setup.py index 8404b0e50..faadea228 100644 --- a/libs/labelbox/tests/integration/test_project_setup.py +++ b/libs/labelbox/tests/integration/test_project_setup.py @@ -9,16 +9,17 @@ def simple_ontology(): - classifications = [{ - "name": "test_ontology", - "instructions": "Which class is this?", - "type": "radio", - "options": [{ - "value": c, - "label": c - } for c in ["one", "two", "three"]], - "required": True, - }] + classifications = [ + { + "name": "test_ontology", + "instructions": "Which class is this?", + "type": "radio", + "options": [ + {"value": c, "label": c} for c in ["one", "two", "three"] + ], + "required": True, + } + ] return {"tools": [], "classifications": classifications} @@ -26,7 +27,8 @@ def simple_ontology(): def test_project_setup(project) -> None: client = project.client labeling_frontends = list( - client.get_labeling_frontends(where=LabelingFrontend.name == 'Editor')) + 
client.get_labeling_frontends(where=LabelingFrontend.name == "Editor") + ) assert len(labeling_frontends) labeling_frontend = labeling_frontends[0] @@ -64,12 +66,14 @@ def test_project_editor_setup(client, project, rand_gen): assert project.ontology().name == ontology_name # Make sure that setup only creates one ontology time.sleep(3) # Search takes a second - assert [ontology.name for ontology in client.get_ontologies(ontology_name) - ] == [ontology_name] + assert [ + ontology.name for ontology in client.get_ontologies(ontology_name) + ] == [ontology_name] def test_project_connect_ontology_cant_call_multiple_times( - client, project, rand_gen): + client, project, rand_gen +): ontology_name = f"test_project_editor_setup_ontology_name-{rand_gen(str)}" ontology = client.create_ontology(ontology_name, simple_ontology()) project.connect_ontology(ontology) diff --git a/libs/labelbox/tests/integration/test_prompt_response_generation_project.py b/libs/labelbox/tests/integration/test_prompt_response_generation_project.py index 20d42d92c..1373ee470 100644 --- a/libs/labelbox/tests/integration/test_prompt_response_generation_project.py +++ b/libs/labelbox/tests/integration/test_prompt_response_generation_project.py @@ -5,19 +5,25 @@ from labelbox.schema.ontology_kind import OntologyKind from labelbox.exceptions import MalformedQueryException + @pytest.mark.parametrize( "prompt_response_ontology, prompt_response_generation_project_with_new_dataset", [ (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), - (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), + ( + MediaType.LLMPromptResponseCreation, + MediaType.LLMPromptResponseCreation, + ), ], - indirect=True + indirect=True, ) def test_prompt_response_generation_ontology_project( - client, prompt_response_ontology, - prompt_response_generation_project_with_new_dataset, - response_data_row, rand_gen): - + client, + prompt_response_ontology, + prompt_response_generation_project_with_new_dataset, + response_data_row, + rand_gen, +): ontology = prompt_response_ontology assert ontology @@ -35,36 +41,41 @@ def test_prompt_response_generation_ontology_project( assert project.ontology().name == ontology.name with pytest.raises( - ValueError, - match="Cannot create batches for auto data generation projects"): + ValueError, + match="Cannot create batches for auto data generation projects", + ): project.create_batch( rand_gen(str), [response_data_row.uid], # sample of data row objects ) with pytest.raises( - ValueError, - match="Cannot create batches for auto data generation projects"): - with patch('labelbox.schema.project.MAX_SYNC_BATCH_ROW_COUNT', - new=0): # force to async - + ValueError, + match="Cannot create batches for auto data generation projects", + ): + with patch( + "labelbox.schema.project.MAX_SYNC_BATCH_ROW_COUNT", new=0 + ): # force to async project.create_batch( rand_gen(str), - [response_data_row.uid - ], # sample of data row objects + [response_data_row.uid], # sample of data row objects ) + @pytest.mark.parametrize( "prompt_response_ontology, prompt_response_generation_project_with_dataset_id", [ (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), - (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), + ( + MediaType.LLMPromptResponseCreation, + MediaType.LLMPromptResponseCreation, + ), ], - indirect=True + indirect=True, ) def test_prompt_response_generation_ontology_project_with_existing_dataset( - prompt_response_ontology, - 
prompt_response_generation_project_with_dataset_id): + prompt_response_ontology, prompt_response_generation_project_with_dataset_id +): ontology = prompt_response_ontology project = prompt_response_generation_project_with_dataset_id @@ -77,48 +88,55 @@ def test_prompt_response_generation_ontology_project_with_existing_dataset( @pytest.fixture def classification_json(): - classifications = [{ - 'featureSchemaId': None, - 'kind': 'Prompt', - 'minCharacters': 2, - 'maxCharacters': 10, - 'name': 'prompt text', - 'instructions': 'prompt text', - 'required': True, - 'schemaNodeId': None, - "scope": "global", - 'type': 'prompt', - 'options': [] - }, { - 'featureSchemaId': None, - 'kind': 'ResponseCheckboxQuestion', - 'name': 'response checklist', - 'instructions': 'response checklist', - 'options': [{'featureSchemaId': None, - 'kind': 'ResponseCheckboxOption', - 'label': 'response checklist option', - 'schemaNodeId': None, - 'position': 0, - 'value': 'option_1'}], - 'required': True, - 'schemaNodeId': None, - "scope": "global", - 'type': 'response-checklist' - }, { - 'featureSchemaId': None, - 'kind': 'ResponseText', - 'maxCharacters': 10, - 'minCharacters': 1, - 'name': 'response text', - 'instructions': 'response text', - 'required': True, - 'schemaNodeId': None, - "scope": "global", - 'type': 'response-text', - 'options': [] - } + classifications = [ + { + "featureSchemaId": None, + "kind": "Prompt", + "minCharacters": 2, + "maxCharacters": 10, + "name": "prompt text", + "instructions": "prompt text", + "required": True, + "schemaNodeId": None, + "scope": "global", + "type": "prompt", + "options": [], + }, + { + "featureSchemaId": None, + "kind": "ResponseCheckboxQuestion", + "name": "response checklist", + "instructions": "response checklist", + "options": [ + { + "featureSchemaId": None, + "kind": "ResponseCheckboxOption", + "label": "response checklist option", + "schemaNodeId": None, + "position": 0, + "value": "option_1", + } + ], + "required": True, + "schemaNodeId": None, + "scope": "global", + "type": "response-checklist", + }, + { + "featureSchemaId": None, + "kind": "ResponseText", + "maxCharacters": 10, + "minCharacters": 1, + "name": "response text", + "instructions": "response text", + "required": True, + "schemaNodeId": None, + "scope": "global", + "type": "response-text", + "options": [], + }, ] - + return classifications @@ -139,7 +157,7 @@ def ontology_from_feature_ids(client, features_from_json): ontology = client.create_ontology_from_feature_schemas( name="test-prompt_response_creation{rand_gen(str)}", feature_schema_ids=feature_ids, - media_type=MediaType.LLMPromptResponseCreation + media_type=MediaType.LLMPromptResponseCreation, ) yield ontology @@ -147,18 +165,22 @@ def ontology_from_feature_ids(client, features_from_json): client.delete_unused_ontology(ontology.uid) -def test_ontology_create_feature_schema(ontology_from_feature_ids, - features_from_json, classification_json): +def test_ontology_create_feature_schema( + ontology_from_feature_ids, features_from_json, classification_json +): created_ontology = ontology_from_feature_ids feature_schema_ids = {f.uid for f in features_from_json} - classifications_normalized = created_ontology.normalized['classifications'] + classifications_normalized = created_ontology.normalized["classifications"] classifications = classification_json for classification in classifications: generated_tool = next( - c for c in classifications_normalized if c['name'] == classification['name']) - assert generated_tool['schemaNodeId'] is not None - 
assert generated_tool['featureSchemaId'] in feature_schema_ids - assert generated_tool['type'] == classification['type'] - assert generated_tool['name'] == classification['name'] - assert generated_tool['required'] == classification['required'] + c + for c in classifications_normalized + if c["name"] == classification["name"] + ) + assert generated_tool["schemaNodeId"] is not None + assert generated_tool["featureSchemaId"] in feature_schema_ids + assert generated_tool["type"] == classification["type"] + assert generated_tool["name"] == classification["name"] + assert generated_tool["required"] == classification["required"] diff --git a/libs/labelbox/tests/integration/test_response_creation_project.py b/libs/labelbox/tests/integration/test_response_creation_project.py index 76ba12d54..d7f9a1e46 100644 --- a/libs/labelbox/tests/integration/test_response_creation_project.py +++ b/libs/labelbox/tests/integration/test_response_creation_project.py @@ -3,11 +3,17 @@ from labelbox.schema.ontology_kind import OntologyKind -@pytest.mark.parametrize("prompt_response_ontology", [OntologyKind.ResponseCreation], indirect=True) -def test_create_response_creation_project(client, rand_gen, - response_creation_project, - prompt_response_ontology, - response_data_row): + +@pytest.mark.parametrize( + "prompt_response_ontology", [OntologyKind.ResponseCreation], indirect=True +) +def test_create_response_creation_project( + client, + rand_gen, + response_creation_project, + prompt_response_ontology, + response_data_row, +): project: Project = response_creation_project assert project @@ -21,4 +27,4 @@ def test_create_response_creation_project(client, rand_gen, rand_gen(str), [response_data_row.uid], # sample of data row objects ) - assert batch \ No newline at end of file + assert batch diff --git a/libs/labelbox/tests/integration/test_send_to_annotate.py b/libs/labelbox/tests/integration/test_send_to_annotate.py index fd358324f..3ba4d13a5 100644 --- a/libs/labelbox/tests/integration/test_send_to_annotate.py +++ b/libs/labelbox/tests/integration/test_send_to_annotate.py @@ -1,11 +1,16 @@ from labelbox import UniqueIds, Project, Ontology, Client -from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy +from labelbox.schema.conflict_resolution_strategy import ( + ConflictResolutionStrategy, +) from typing import List def test_send_to_annotate_include_annotations( - client: Client, configured_batch_project_with_label: Project, - project_pack: List[Project], ontology: Ontology): + client: Client, + configured_batch_project_with_label: Project, + project_pack: List[Project], + ontology: Ontology, +): [source_project, _, data_row, _] = configured_batch_project_with_label destination_project: Project = project_pack[0] @@ -14,18 +19,22 @@ def test_send_to_annotate_include_annotations( # build an ontology mapping using the top level tools src_feature_schema_ids = list( - tool.feature_schema_id for tool in src_ontology.tools()) + tool.feature_schema_id for tool in src_ontology.tools() + ) dest_ontology = destination_project.ontology() dest_feature_schema_ids = list( - tool.feature_schema_id for tool in dest_ontology.tools()) + tool.feature_schema_id for tool in dest_ontology.tools() + ) # create a dictionary of feature schema id to itself - ontology_mapping = dict(zip(src_feature_schema_ids, - dest_feature_schema_ids)) + ontology_mapping = dict( + zip(src_feature_schema_ids, dest_feature_schema_ids) + ) try: queues = destination_project.task_queues() initial_review_task = next( - q for q in 
queues if q.name == "Initial review task") + q for q in queues if q.name == "Initial review task" + ) # Send the data row to the new project task = client.send_to_annotate_from_catalog( @@ -34,13 +43,11 @@ def test_send_to_annotate_include_annotations( batch_name="test-batch", data_rows=UniqueIds([data_row.uid]), params={ - "source_project_id": - source_project.uid, - "annotations_ontology_mapping": - ontology_mapping, - "override_existing_annotations_rule": - ConflictResolutionStrategy.OverrideWithAnnotations - }) + "source_project_id": source_project.uid, + "annotations_ontology_mapping": ontology_mapping, + "override_existing_annotations_rule": ConflictResolutionStrategy.OverrideWithAnnotations, + }, + ) task.wait_till_done() @@ -57,7 +64,7 @@ def test_send_to_annotate_include_annotations( assert destination_data_rows[0] == data_row.uid # Verify annotations were copied into the destination project - destination_project_labels = (list(destination_project.labels())) + destination_project_labels = list(destination_project.labels()) assert len(destination_project_labels) == 1 finally: destination_project.delete() diff --git a/libs/labelbox/tests/integration/test_task.py b/libs/labelbox/tests/integration/test_task.py index b0eac2fa1..da89e4bb0 100644 --- a/libs/labelbox/tests/integration/test_task.py +++ b/libs/labelbox/tests/integration/test_task.py @@ -9,42 +9,50 @@ def test_task_errors(dataset, image_url, snapshot): client = dataset.client - task = dataset.create_data_rows([ - { - DataRow.row_data: - image_url, - DataRow.metadata_fields: [ - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, - value='some msg'), - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, - value='some msg 2') - ] - }, - ]) + task = dataset.create_data_rows( + [ + { + DataRow.row_data: image_url, + DataRow.metadata_fields: [ + DataRowMetadataField( + schema_id=TEXT_SCHEMA_ID, value="some msg" + ), + DataRowMetadataField( + schema_id=TEXT_SCHEMA_ID, value="some msg 2" + ), + ], + }, + ] + ) assert task in client.get_user().created_tasks() task.wait_till_done() assert len(task.failed_data_rows) == 1 - assert "A schemaId can only be specified once per DataRow : [cko8s9r5v0001h2dk9elqdidh]" in task.failed_data_rows[ - 0]['message'] - assert len(task.failed_data_rows[0]['failedDataRows'][0]['metadata']) == 2 + assert ( + "A schemaId can only be specified once per DataRow : [cko8s9r5v0001h2dk9elqdidh]" + in task.failed_data_rows[0]["message"] + ) + assert len(task.failed_data_rows[0]["failedDataRows"][0]["metadata"]) == 2 dt = client.get_task_by_id(task.uid) assert dt.status == "COMPLETE" assert len(dt.errors) == 1 - assert dt.errors[0]['message'].startswith( - "A schemaId can only be specified once per DataRow") + assert dt.errors[0]["message"].startswith( + "A schemaId can only be specified once per DataRow" + ) assert dt.result is None def test_task_success_json(dataset, image_url, snapshot): client = dataset.client - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url, - }, - ]) + task = dataset.create_data_rows( + [ + { + DataRow.row_data: image_url, + }, + ] + ) assert task in client.get_user().created_tasks() task.wait_till_done() assert task.status == "COMPLETE" @@ -54,14 +62,16 @@ def test_task_success_json(dataset, image_url, snapshot): assert task.result_url is not None assert isinstance(task.result_url, str) task_result = task.result[0] - assert 'id' in task_result and isinstance(task_result['id'], str) - assert 'row_data' in task_result and isinstance(task_result['row_data'], - str) + assert "id" 
in task_result and isinstance(task_result["id"], str) + assert "row_data" in task_result and isinstance( + task_result["row_data"], str + ) snapshot.snapshot_dir = INTEGRATION_SNAPSHOT_DIRECTORY - task_result['id'] = 'DUMMY_ID' - task_result['row_data'] = 'https://dummy.url' - snapshot.assert_match(json.dumps(task_result), - 'test_task.test_task_success_json.json') + task_result["id"] = "DUMMY_ID" + task_result["row_data"] = "https://dummy.url" + snapshot.assert_match( + json.dumps(task_result), "test_task.test_task_success_json.json" + ) assert len(task.result) dt = client.get_task_by_id(task.uid) diff --git a/libs/labelbox/tests/integration/test_task_queue.py b/libs/labelbox/tests/integration/test_task_queue.py index 2a6ca45d8..835f67219 100644 --- a/libs/labelbox/tests/integration/test_task_queue.py +++ b/libs/labelbox/tests/integration/test_task_queue.py @@ -7,7 +7,8 @@ def test_get_task_queue(project: Project): task_queues = project.task_queues() assert len(task_queues) == 3 review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") + tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE" + ) assert review_queue @@ -23,6 +24,7 @@ def test_get_overview_no_details(project: Project): assert isinstance(po.labeled, int) assert isinstance(po.total_data_rows, int) + def test_get_overview_with_details(project: Project): po = project.get_overview(details=True) @@ -37,20 +39,23 @@ def test_get_overview_with_details(project: Project): assert isinstance(po.labeled, int) assert isinstance(po.total_data_rows, int) + def _validate_moved(project, queue_name, data_row_count): timeout_seconds = 30 sleep_time = 2 while True: task_queues = project.task_queues() review_queue = next( - tq for tq in task_queues if tq.queue_type == queue_name) + tq for tq in task_queues if tq.queue_type == queue_name + ) if review_queue.data_row_count == data_row_count: break if timeout_seconds <= 0: raise AssertionError( - "Timed out expecting data_row_count of 1 in the review queue") + "Timed out expecting data_row_count of 1 in the review queue" + ) timeout_seconds -= sleep_time time.sleep(sleep_time) @@ -61,18 +66,23 @@ def test_move_to_task(configured_batch_project_with_label): task_queues = project.task_queues() review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") + tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE" + ) project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid) _validate_moved(project, "MANUAL_REVIEW_QUEUE", 1) review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REWORK_QUEUE") - project.move_data_rows_to_task_queue(GlobalKeys([data_row.global_key]), - review_queue.uid) + tq for tq in task_queues if tq.queue_type == "MANUAL_REWORK_QUEUE" + ) + project.move_data_rows_to_task_queue( + GlobalKeys([data_row.global_key]), review_queue.uid + ) _validate_moved(project, "MANUAL_REWORK_QUEUE", 1) review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") - project.move_data_rows_to_task_queue(UniqueIds([data_row.uid]), - review_queue.uid) + tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE" + ) + project.move_data_rows_to_task_queue( + UniqueIds([data_row.uid]), review_queue.uid + ) _validate_moved(project, "MANUAL_REVIEW_QUEUE", 1) diff --git a/libs/labelbox/tests/integration/test_user_and_org.py b/libs/labelbox/tests/integration/test_user_and_org.py index ca158527c..7bb72051f 100644 --- 
a/libs/labelbox/tests/integration/test_user_and_org.py +++ b/libs/labelbox/tests/integration/test_user_and_org.py @@ -20,4 +20,4 @@ def test_user_and_org_projects(client, project): org_project = org.projects(where=Project.uid == project.uid) assert user_project - assert org_project \ No newline at end of file + assert org_project diff --git a/libs/labelbox/tests/integration/test_user_management.py b/libs/labelbox/tests/integration/test_user_management.py index ca4328f51..cfdf3c566 100644 --- a/libs/labelbox/tests/integration/test_user_management.py +++ b/libs/labelbox/tests/integration/test_user_management.py @@ -8,14 +8,17 @@ @pytest.fixture def org_invite(client, organization, environ, queries): - role = client.get_roles()['LABELER'] + role = client.get_roles()["LABELER"] - dummy_email = "none+{}@labelbox.com".format("".join( - faker.random_letters(26))) + dummy_email = "none+{}@labelbox.com".format( + "".join(faker.random_letters(26)) + ) invite_limit = organization.invite_limit() if environ.value == "prod": - assert invite_limit.remaining > 0, "No invites available for the account associated with this key." + assert ( + invite_limit.remaining > 0 + ), "No invites available for the account associated with this key." elif environ.value != "staging": # Cannot run against local return @@ -31,26 +34,29 @@ def org_invite(client, organization, environ, queries): def project_role_1(client, project_pack): project_1, _ = project_pack roles = client.get_roles() - return ProjectRole(project=project_1, role=roles['LABELER']) + return ProjectRole(project=project_1, role=roles["LABELER"]) @pytest.fixture def project_role_2(client, project_pack): _, project_2 = project_pack roles = client.get_roles() - return ProjectRole(project=project_2, role=roles['REVIEWER']) + return ProjectRole(project=project_2, role=roles["REVIEWER"]) @pytest.fixture -def create_project_invite(client, organization, project_pack, queries, - project_role_1, project_role_2): +def create_project_invite( + client, organization, project_pack, queries, project_role_1, project_role_2 +): roles = client.get_roles() - dummy_email = "none+{}@labelbox.com".format("".join( - faker.random_letters(26))) + dummy_email = "none+{}@labelbox.com".format( + "".join(faker.random_letters(26)) + ) invite = organization.invite_user( dummy_email, - roles['NONE'], - project_roles=[project_role_1, project_role_2]) + roles["NONE"], + project_roles=[project_role_1, project_role_2], + ) yield invite @@ -59,10 +65,9 @@ def create_project_invite(client, organization, project_pack, queries, def test_org_invite(client, organization, environ, queries, org_invite): invite, invite_limit = org_invite - role = client.get_roles()['LABELER'] + role = client.get_roles()["LABELER"] if environ.value == "prod": - invite_limit_after = organization.invite_limit() # One user added assert invite_limit.remaining - invite_limit_after.remaining == 1 @@ -75,7 +80,8 @@ def test_org_invite(client, organization, environ, queries, org_invite): if outstanding_invite.uid == invite.uid: in_list = True org_role = outstanding_invite.organization_role_name.lower() - assert org_role == role.name.lower( + assert ( + org_role == role.name.lower() ), "Role should be labeler. 
Found {org_role} " assert in_list, "Invite not found" @@ -85,44 +91,67 @@ def test_cancel_invite( organization, queries, ): - role = client.get_roles()['LABELER'] - dummy_email = "none+{}@labelbox.com".format("".join( - faker.random_letters(26))) + role = client.get_roles()["LABELER"] + dummy_email = "none+{}@labelbox.com".format( + "".join(faker.random_letters(26)) + ) invite = organization.invite_user(dummy_email, role) queries.cancel_invite(client, invite.uid) outstanding_invites = [i.uid for i in queries.get_invites(client)] assert invite.uid not in outstanding_invites -def test_project_invite(client, organization, project_pack, queries, - create_project_invite, project_role_1, project_role_2): +def test_project_invite( + client, + organization, + project_pack, + queries, + create_project_invite, + project_role_1, + project_role_2, +): create_project_invite project_1, _ = project_pack roles = client.get_roles() project_invite = next(queries.get_project_invites(client, project_1.uid)) - assert set([(proj_invite.project.uid, proj_invite.role.uid) - for proj_invite in project_invite.project_roles - ]) == set([(proj_role.project.uid, proj_role.role.uid) - for proj_role in [project_role_1, project_role_2]]) - - assert set([(proj_invite.project.uid, proj_invite.role.uid) - for proj_invite in project_invite.project_roles - ]) == set([(proj_role.project.uid, proj_role.role.uid) - for proj_role in [project_role_1, project_role_2]]) + assert set( + [ + (proj_invite.project.uid, proj_invite.role.uid) + for proj_invite in project_invite.project_roles + ] + ) == set( + [ + (proj_role.project.uid, proj_role.role.uid) + for proj_role in [project_role_1, project_role_2] + ] + ) + + assert set( + [ + (proj_invite.project.uid, proj_invite.role.uid) + for proj_invite in project_invite.project_roles + ] + ) == set( + [ + (proj_role.project.uid, proj_role.role.uid) + for proj_role in [project_role_1, project_role_2] + ] + ) project_members = project_1.members() project_member = [ - member for member in project_members + member + for member in project_members if member.user().uid == client.get_user().uid ] assert len(project_member) == 1 project_member = project_member[0] - assert project_member.access_from == 'ORGANIZATION' - assert project_member.role().name.upper() == roles['ADMIN'].name.upper() + assert project_member.access_from == "ORGANIZATION" + assert project_member.role().name.upper() == roles["ADMIN"].name.upper() @pytest.mark.skip( @@ -131,8 +160,7 @@ def test_project_invite(client, organization, project_pack, queries, def test_member_management(client, organization, project, project_based_user): roles = client.get_roles() assert not len(list(project_based_user.projects())) - for role in [roles['LABELER'], roles['REVIEWER']]: - + for role in [roles["LABELER"], roles["REVIEWER"]]: project_based_user.upsert_project_role(project, role=role) members = project.members() is_member = False @@ -148,11 +176,14 @@ def test_member_management(client, organization, project, project_based_user): for member in project.members(): assert member.user().uid != project_based_user.uid - assert project_based_user.org_role().name.upper( - ) == roles['NONE'].name.upper() + assert ( + project_based_user.org_role().name.upper() == roles["NONE"].name.upper() + ) for role in [ - roles['TEAM_MANAGER'], roles['ADMIN'], roles['LABELER'], - roles['REVIEWER'] + roles["TEAM_MANAGER"], + roles["ADMIN"], + roles["LABELER"], + roles["REVIEWER"], ]: project_based_user.update_org_role(role) project_based_user.org_role().name.upper() 
== role.name.upper() diff --git a/libs/labelbox/tests/integration/test_webhook.py b/libs/labelbox/tests/integration/test_webhook.py index 25c8c667a..b93255c4e 100644 --- a/libs/labelbox/tests/integration/test_webhook.py +++ b/libs/labelbox/tests/integration/test_webhook.py @@ -25,19 +25,25 @@ def test_webhook_create_update(project, rand_gen): with pytest.raises(ValueError) as exc_info: webhook.update(status="invalid..") valid_webhook_statuses = {item.value for item in Webhook.Status} - assert str(exc_info.value) == \ - f"Value `invalid..` does not exist in supported values. Expected one of {valid_webhook_statuses}" + assert ( + str(exc_info.value) + == f"Value `invalid..` does not exist in supported values. Expected one of {valid_webhook_statuses}" + ) with pytest.raises(ValueError) as exc_info: webhook.update(topics=["invalid.."]) valid_webhook_topics = {item.value for item in Webhook.Topic} - assert str(exc_info.value) == \ - f"Value `invalid..` does not exist in supported values. Expected one of {valid_webhook_topics}" + assert ( + str(exc_info.value) + == f"Value `invalid..` does not exist in supported values. Expected one of {valid_webhook_topics}" + ) with pytest.raises(TypeError) as exc_info: webhook.update(topics="invalid..") - assert str(exc_info.value) == \ - "Topics must be List[Webhook.Topic]. Found `invalid..`" + assert ( + str(exc_info.value) + == "Topics must be List[Webhook.Topic]. Found `invalid..`" + ) webhook.delete() @@ -50,8 +56,7 @@ def test_webhook_create_with_no_secret(project, rand_gen): with pytest.raises(ValueError) as exc_info: Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "Secret must be a non-empty string." + assert str(exc_info.value) == "Secret must be a non-empty string." def test_webhook_create_with_no_topics(project, rand_gen): @@ -62,8 +67,7 @@ def test_webhook_create_with_no_topics(project, rand_gen): with pytest.raises(ValueError) as exc_info: Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "Topics must be a non-empty list." + assert str(exc_info.value) == "Topics must be a non-empty list." def test_webhook_create_with_no_url(project, rand_gen): @@ -74,5 +78,4 @@ def test_webhook_create_with_no_url(project, rand_gen): with pytest.raises(ValueError) as exc_info: Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "URL must be a non-empty string." + assert str(exc_info.value) == "URL must be a non-empty string." 
diff --git a/libs/labelbox/tests/unit/conftest.py b/libs/labelbox/tests/unit/conftest.py index 0e8de8185..603fa9908 100644 --- a/libs/labelbox/tests/unit/conftest.py +++ b/libs/labelbox/tests/unit/conftest.py @@ -6,40 +6,25 @@ def ndjson_content(): line = """{"uuid": "9fd9a92e-2560-4e77-81d4-b2e955800092", "schemaId": "ckaeasyfk004y0y7wyye5epgu", "dataRow": {"id": "ck7kftpan8ir008910yf07r9c"}, "bbox": {"top": 48, "left": 58, "height": 865, "width": 1512}} {"uuid": "29b878f3-c2b4-4dbf-9f22-a795f0720125", "schemaId": "ckapgvrl7007q0y7ujkjkaaxt", "dataRow": {"id": "ck7kftpan8ir008910yf07r9c"}, "polygon": [{"x": 147.692, "y": 118.154}, {"x": 142.769, "y": 404.923}, {"x": 57.846, "y": 318.769}, {"x": 28.308, "y": 169.846}]}""" - expected_objects = [{ - 'uuid': '9fd9a92e-2560-4e77-81d4-b2e955800092', - 'schemaId': 'ckaeasyfk004y0y7wyye5epgu', - 'dataRow': { - 'id': 'ck7kftpan8ir008910yf07r9c' + expected_objects = [ + { + "uuid": "9fd9a92e-2560-4e77-81d4-b2e955800092", + "schemaId": "ckaeasyfk004y0y7wyye5epgu", + "dataRow": {"id": "ck7kftpan8ir008910yf07r9c"}, + "bbox": {"top": 48, "left": 58, "height": 865, "width": 1512}, }, - 'bbox': { - 'top': 48, - 'left': 58, - 'height': 865, - 'width': 1512 - } - }, { - 'uuid': - '29b878f3-c2b4-4dbf-9f22-a795f0720125', - 'schemaId': - 'ckapgvrl7007q0y7ujkjkaaxt', - 'dataRow': { - 'id': 'ck7kftpan8ir008910yf07r9c' + { + "uuid": "29b878f3-c2b4-4dbf-9f22-a795f0720125", + "schemaId": "ckapgvrl7007q0y7ujkjkaaxt", + "dataRow": {"id": "ck7kftpan8ir008910yf07r9c"}, + "polygon": [ + {"x": 147.692, "y": 118.154}, + {"x": 142.769, "y": 404.923}, + {"x": 57.846, "y": 318.769}, + {"x": 28.308, "y": 169.846}, + ], }, - 'polygon': [{ - 'x': 147.692, - 'y': 118.154 - }, { - 'x': 142.769, - 'y': 404.923 - }, { - 'x': 57.846, - 'y': 318.769 - }, { - 'x': 28.308, - 'y': 169.846 - }] - }] + ] return line, expected_objects @@ -47,65 +32,55 @@ def ndjson_content(): @pytest.fixture def ndjson_content_with_nonascii_and_line_breaks(): line = '{"id": "2489651127", "type": "PushEvent", "actor": {"id": 1459915, "login": "xtuaok", "gravatar_id": "", "url": "https://api.github.com/users/xtuaok", "avatar_url": "https://avatars.githubusercontent.com/u/1459915?"}, "repo": {"id": 6719841, "name": "xtuaok/twitter_track_following", "url": "https://api.github.com/repos/xtuaok/twitter_track_following"}, "payload": {"push_id": 536864008, "size": 1, "distinct_size": 1, "ref": "refs/heads/xtuaok", "head": "afb8afe306c7893d93d383a06e4d9df53b41bf47", "before": "4671b4868f1a060f2ed64d8268cd22d514a84e63", "commits": [{"sha": "afb8afe306c7893d93d383a06e4d9df53b41bf47", "author": {"email": "47cb89439b2d6961b59dff4298e837f67aa77389@gmail.com", "name": "Tomonori Tamagawa"}, "message": "Update ID 949438177,, - screen_name: chomado, - name: ちょまど@初詣おみくじ凶, - description: ( *゚▽゚* っ)З腐女子!絵描き!| H26新卒文系SE (入社して4ヶ月目の8月にSIer(適応障害になった)を辞職し開発者に転職) | H26秋応用情報合格!| 自作bot (in PHP) chomado_bot | プログラミングガチ初心者, - location:", "distinct": true, "url": "https://api.github.com/repos/xtuaok/twitter_track_following/commits/afb8afe306c7893d93d383a06e4d9df53b41bf47"}]}, "public": true, "created_at": "2015-01-01T15:00:10Z"}' - expected_objects = [{ - 'id': '2489651127', - 'type': 'PushEvent', - 'actor': { - 'id': 1459915, - 'login': 'xtuaok', - 'gravatar_id': '', - 'url': 'https://api.github.com/users/xtuaok', - 'avatar_url': 'https://avatars.githubusercontent.com/u/1459915?' 
- }, - 'repo': { - 'id': 6719841, - 'name': 'xtuaok/twitter_track_following', - 'url': 'https://api.github.com/repos/xtuaok/twitter_track_following' - }, - 'payload': { - 'push_id': - 536864008, - 'size': - 1, - 'distinct_size': - 1, - 'ref': - 'refs/heads/xtuaok', - 'head': - 'afb8afe306c7893d93d383a06e4d9df53b41bf47', - 'before': - '4671b4868f1a060f2ed64d8268cd22d514a84e63', - 'commits': [{ - 'sha': - 'afb8afe306c7893d93d383a06e4d9df53b41bf47', - 'author': { - 'email': - '47cb89439b2d6961b59dff4298e837f67aa77389@gmail.com', - 'name': - 'Tomonori Tamagawa' - }, - 'message': - 'Update ID 949438177,, - screen_name: chomado, - name: ちょまど@初詣おみくじ凶, - description: ( *゚▽゚* っ)З腐女子!絵描き!| H26新卒文系SE (入社して4ヶ月目の8月にSIer(適応障害になった)を辞職し開発者に転職) | H26秋応用情報合格!| 自作bot (in PHP) chomado_bot | プログラミングガチ初心者, - location:', - 'distinct': - True, - 'url': - 'https://api.github.com/repos/xtuaok/twitter_track_following/commits/afb8afe306c7893d93d383a06e4d9df53b41bf47' - }] - }, - 'public': True, - 'created_at': '2015-01-01T15:00:10Z' - }] + expected_objects = [ + { + "id": "2489651127", + "type": "PushEvent", + "actor": { + "id": 1459915, + "login": "xtuaok", + "gravatar_id": "", + "url": "https://api.github.com/users/xtuaok", + "avatar_url": "https://avatars.githubusercontent.com/u/1459915?", + }, + "repo": { + "id": 6719841, + "name": "xtuaok/twitter_track_following", + "url": "https://api.github.com/repos/xtuaok/twitter_track_following", + }, + "payload": { + "push_id": 536864008, + "size": 1, + "distinct_size": 1, + "ref": "refs/heads/xtuaok", + "head": "afb8afe306c7893d93d383a06e4d9df53b41bf47", + "before": "4671b4868f1a060f2ed64d8268cd22d514a84e63", + "commits": [ + { + "sha": "afb8afe306c7893d93d383a06e4d9df53b41bf47", + "author": { + "email": "47cb89439b2d6961b59dff4298e837f67aa77389@gmail.com", + "name": "Tomonori Tamagawa", + }, + "message": "Update ID 949438177,, - screen_name: chomado, - name: ちょまど@初詣おみくじ凶, - description: ( *゚▽゚* っ)З腐女子!絵描き!| H26新卒文系SE (入社して4ヶ月目の8月にSIer(適応障害になった)を辞職し開発者に転職) | H26秋応用情報合格!| 自作bot (in PHP) chomado_bot | プログラミングガチ初心者, - location:", + "distinct": True, + "url": "https://api.github.com/repos/xtuaok/twitter_track_following/commits/afb8afe306c7893d93d383a06e4d9df53b41bf47", + } + ], + }, + "public": True, + "created_at": "2015-01-01T15:00:10Z", + } + ] return line, expected_objects @pytest.fixture def generate_random_ndjson(rand_gen): - def _generate_random_ndjson(lines: int = 10): return [ - json.dumps({"data_row": { - "id": rand_gen(str) - }}) for _ in range(lines) + json.dumps({"data_row": {"id": rand_gen(str)}}) + for _ in range(lines) ] return _generate_random_ndjson @@ -113,9 +88,7 @@ def _generate_random_ndjson(lines: int = 10): @pytest.fixture def mock_response(): - class MockResponse: - def __init__(self, text: str, exception: Exception = None) -> None: self._text = text self._exception = exception diff --git a/libs/labelbox/tests/unit/export_task/test_export_task.py b/libs/labelbox/tests/unit/export_task/test_export_task.py index 50f08191b..ac84a875b 100644 --- a/libs/labelbox/tests/unit/export_task/test_export_task.py +++ b/libs/labelbox/tests/unit/export_task/test_export_task.py @@ -6,9 +6,8 @@ class TestExportTask: - def test_export_task(self): - with patch('requests.get') as mock_requests_get: + with patch("requests.get") as mock_requests_get: mock_task = MagicMock() mock_task.client.execute.side_effect = [ { @@ -16,15 +15,9 @@ def test_export_task(self): "exportMetadataHeader": { "total_size": 1, "total_lines": 1, - "lines": { - "start": 0, - "end": 1 - }, - 
"offsets": { - "start": 0, - "end": 0 - }, - "file": "file" + "lines": {"start": 0, "end": 1}, + "offsets": {"start": 0, "end": 0}, + "file": "file", } } }, @@ -33,15 +26,9 @@ def test_export_task(self): "exportFileFromOffset": { "total_size": 1, "total_lines": 1, - "lines": { - "start": 0, - "end": 1 - }, - "offsets": { - "start": 0, - "end": 0 - }, - "file": "file" + "lines": {"start": 0, "end": 1}, + "offsets": {"start": 0, "end": 0}, + "file": "file", } } }, @@ -49,8 +36,7 @@ def test_export_task(self): mock_task.status = "COMPLETE" data = { "data_row": { - "raw_data": - """ + "raw_data": """ {"raw_text":"}{"} {"raw_text":"\\nbad"} """ @@ -76,7 +62,7 @@ def test_get_buffered_stream_failed(self): export_task.get_buffered_stream() def test_get_buffered_stream(self): - with patch('requests.get') as mock_requests_get: + with patch("requests.get") as mock_requests_get: mock_task = MagicMock() mock_task.client.execute.side_effect = [ { @@ -84,15 +70,9 @@ def test_get_buffered_stream(self): "exportMetadataHeader": { "total_size": 1, "total_lines": 1, - "lines": { - "start": 0, - "end": 1 - }, - "offsets": { - "start": 0, - "end": 0 - }, - "file": "file" + "lines": {"start": 0, "end": 1}, + "offsets": {"start": 0, "end": 0}, + "file": "file", } } }, @@ -101,15 +81,9 @@ def test_get_buffered_stream(self): "exportFileFromOffset": { "total_size": 1, "total_lines": 1, - "lines": { - "start": 0, - "end": 1 - }, - "offsets": { - "start": 0, - "end": 0 - }, - "file": "file" + "lines": {"start": 0, "end": 1}, + "offsets": {"start": 0, "end": 0}, + "file": "file", } } }, @@ -117,8 +91,7 @@ def test_get_buffered_stream(self): mock_task.status = "COMPLETE" data = { "data_row": { - "raw_data": - """ + "raw_data": """ {"raw_text":"}{"} {"raw_text":"\\nbad"} """ @@ -128,11 +101,13 @@ def test_get_buffered_stream(self): mock_requests_get.return_value.content = "b" export_task = ExportTask(mock_task, is_export_v2=True) output_data = [] - export_task.get_buffered_stream().start(stream_handler=lambda x: output_data.append(x.json)) + export_task.get_buffered_stream().start( + stream_handler=lambda x: output_data.append(x.json) + ) assert data == output_data[0] def test_export_task_bad_offsets(self): - with patch('requests.get') as mock_requests_get: + with patch("requests.get") as mock_requests_get: mock_task = MagicMock() mock_task.client.execute.side_effect = [ { @@ -140,15 +115,9 @@ def test_export_task_bad_offsets(self): "exportMetadataHeader": { "total_size": 1, "total_lines": 1, - "lines": { - "start": 0, - "end": 1 - }, - "offsets": { - "start": 0, - "end": 0 - }, - "file": "file" + "lines": {"start": 0, "end": 1}, + "offsets": {"start": 0, "end": 0}, + "file": "file", } } }, @@ -157,15 +126,9 @@ def test_export_task_bad_offsets(self): "exportFileFromOffset": { "total_size": 1, "total_lines": 1, - "lines": { - "start": 0, - "end": 1 - }, - "offsets": { - "start": 0, - "end": 0 - }, - "file": "file" + "lines": {"start": 0, "end": 1}, + "offsets": {"start": 0, "end": 0}, + "file": "file", } } }, @@ -173,21 +136,17 @@ def test_export_task_bad_offsets(self): mock_task.status = "COMPLETE" data = { "data_row": { - "id": - "clwb6wvpv3mpx0712aafl9m00", - "external_id": - "43cdad5e-1fcf-450d-ad72-df4460edf973", - "global_key": - "9ab56c5a-5c2f-45ae-8e21-e53eb415cefe", - "row_data": - "{\"type\":\"application/vnd.labelbox.conversational\",\"version\":1,\"messages\":[{\"messageId\":\"message-0\",\"timestampUsec\":1530718491,\"content\":\"The minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$ 
is\",\"user\":{\"userId\":\"prompt\",\"name\":\"prompt\"},\"align\":\"left\",\"canLabel\":true}],\"modelOutputs\":[{\"title\":\"Response 1\",\"content\":\"To find the minimum value of the expression $3 \\\\cos x + 4 \\\\sin x + 8$, we can use the fact that $a\\\\cos x+b\\\\sin x=\\\\sqrt{a^2+b^2}\\\\left(\\\\frac{a}{\\\\sqrt{a^2+b^2}}\\\\cos x+\\\\frac{b}{\\\\sqrt{a^2+b^2}}\\\\sin x\\\\right)$. This allows us to rewrite the expression as:\\n\\n$3\\\\cos x+4\\\\sin x+8=\\\\sqrt{3^2+4^2}\\\\left(\\\\frac{3}{\\\\sqrt{3^2+4^2}}\\\\cos x+\\\\frac{4}{\\\\sqrt{3^2+4^2}}\\\\sin x\\\\right)+8=5\\\\left(\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x\\\\right)+8$\\n\\nNow, let's consider the expression $\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x$. Since $\\\\left(\\\\frac{3}{5}\\\\right)^2+\\\\left(\\\\frac{4}{5}\\\\right)^2=1$, we can write $\\\\frac{3}{5}=\\\\cos\\\\theta$ and $\\\\frac{4}{5}=\\\\sin\\\\theta$ for some angle $\\\\theta$. Then:\\n\\n$\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x=\\\\cos\\\\theta\\\\cos x+\\\\sin\\\\theta\\\\sin x=\\\\cos(x-\\\\theta)$\\n\\nSo, the original expression can be written as:\\n\\n$5\\\\cos(x-\\\\theta)+8$\\n\\nSince the minimum value of $\\\\cos(x-\\\\theta)$ is $-1$, the minimum value of the original expression is:\\n\\n$5(-1)+8=-5+8=3$\\n\\nTherefore, the minimum value of $3\\\\cos x + 4\\\\sin x + 8$ is $\\\\boxed{3}$.\",\"modelConfigName\":\"null\"},{\"title\":\"Response 2\",\"content\":\"A nice math question!\\n\\nTo find the minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$, we can use the fact that $a\\\\cos x + b\\\\sin x = \\\\sqrt{a^2 + b^2} \\\\cos(x - \\\\alpha)$, where $\\\\alpha = \\\\tan^{-1}\\\\left(\\\\frac{b}{a}\\\\right)$.\\n\\nIn this case, $a = 3$ and $b = 4$, so $\\\\alpha = \\\\tan^{-1}\\\\left(\\\\frac{4}{3}\\\\right)$.\\n\\nSo, we have:\\n\\n$$3 \\\\cos x + 4 \\\\sin x + 8 = \\\\sqrt{3^2 + 4^2} \\\\cos(x - \\\\alpha) + 8 = 5 \\\\cos(x - \\\\alpha) + 8$$\\n\\nNow, the minimum value of $\\\\cos(x - \\\\alpha)$ is $-1$, so the minimum value of $5 \\\\cos(x - \\\\alpha) + 8$ is:\\n\\n$$5(-1) + 8 = -5 + 8 = 3$$\\n\\nTherefore, the minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$ is $\\\\boxed{3}$.\\n\\nLet me know if you have any questions or need further clarification!\",\"modelConfigName\":\"null\"}]}" + "id": "clwb6wvpv3mpx0712aafl9m00", + "external_id": "43cdad5e-1fcf-450d-ad72-df4460edf973", + "global_key": "9ab56c5a-5c2f-45ae-8e21-e53eb415cefe", + "row_data": '{"type":"application/vnd.labelbox.conversational","version":1,"messages":[{"messageId":"message-0","timestampUsec":1530718491,"content":"The minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$ is","user":{"userId":"prompt","name":"prompt"},"align":"left","canLabel":true}],"modelOutputs":[{"title":"Response 1","content":"To find the minimum value of the expression $3 \\\\cos x + 4 \\\\sin x + 8$, we can use the fact that $a\\\\cos x+b\\\\sin x=\\\\sqrt{a^2+b^2}\\\\left(\\\\frac{a}{\\\\sqrt{a^2+b^2}}\\\\cos x+\\\\frac{b}{\\\\sqrt{a^2+b^2}}\\\\sin x\\\\right)$. This allows us to rewrite the expression as:\\n\\n$3\\\\cos x+4\\\\sin x+8=\\\\sqrt{3^2+4^2}\\\\left(\\\\frac{3}{\\\\sqrt{3^2+4^2}}\\\\cos x+\\\\frac{4}{\\\\sqrt{3^2+4^2}}\\\\sin x\\\\right)+8=5\\\\left(\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x\\\\right)+8$\\n\\nNow, let\'s consider the expression $\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x$. 
Since $\\\\left(\\\\frac{3}{5}\\\\right)^2+\\\\left(\\\\frac{4}{5}\\\\right)^2=1$, we can write $\\\\frac{3}{5}=\\\\cos\\\\theta$ and $\\\\frac{4}{5}=\\\\sin\\\\theta$ for some angle $\\\\theta$. Then:\\n\\n$\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x=\\\\cos\\\\theta\\\\cos x+\\\\sin\\\\theta\\\\sin x=\\\\cos(x-\\\\theta)$\\n\\nSo, the original expression can be written as:\\n\\n$5\\\\cos(x-\\\\theta)+8$\\n\\nSince the minimum value of $\\\\cos(x-\\\\theta)$ is $-1$, the minimum value of the original expression is:\\n\\n$5(-1)+8=-5+8=3$\\n\\nTherefore, the minimum value of $3\\\\cos x + 4\\\\sin x + 8$ is $\\\\boxed{3}$.","modelConfigName":"null"},{"title":"Response 2","content":"A nice math question!\\n\\nTo find the minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$, we can use the fact that $a\\\\cos x + b\\\\sin x = \\\\sqrt{a^2 + b^2} \\\\cos(x - \\\\alpha)$, where $\\\\alpha = \\\\tan^{-1}\\\\left(\\\\frac{b}{a}\\\\right)$.\\n\\nIn this case, $a = 3$ and $b = 4$, so $\\\\alpha = \\\\tan^{-1}\\\\left(\\\\frac{4}{3}\\\\right)$.\\n\\nSo, we have:\\n\\n$$3 \\\\cos x + 4 \\\\sin x + 8 = \\\\sqrt{3^2 + 4^2} \\\\cos(x - \\\\alpha) + 8 = 5 \\\\cos(x - \\\\alpha) + 8$$\\n\\nNow, the minimum value of $\\\\cos(x - \\\\alpha)$ is $-1$, so the minimum value of $5 \\\\cos(x - \\\\alpha) + 8$ is:\\n\\n$$5(-1) + 8 = -5 + 8 = 3$$\\n\\nTherefore, the minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$ is $\\\\boxed{3}$.\\n\\nLet me know if you have any questions or need further clarification!","modelConfigName":"null"}]}', }, "media_attributes": { "asset_type": "conversational", "mime_type": "application/vnd.labelbox.conversational", "labelable_ids": ["message-0"], - "message_count": 1 - } + "message_count": 1, + }, } mock_requests_get.return_value.text = json.dumps(data) mock_requests_get.return_value.content = "b" diff --git a/libs/labelbox/tests/unit/export_task/test_unit_file_converter.py b/libs/labelbox/tests/unit/export_task/test_unit_file_converter.py index 3f3af9521..81e9eb60f 100644 --- a/libs/labelbox/tests/unit/export_task/test_unit_file_converter.py +++ b/libs/labelbox/tests/unit/export_task/test_unit_file_converter.py @@ -12,7 +12,6 @@ class TestFileConverter: - def test_with_correct_ndjson(self, tmp_path, generate_random_ndjson): directory = tmp_path / "file-converter" directory.mkdir() @@ -24,8 +23,9 @@ def test_with_correct_ndjson(self, tmp_path, generate_random_ndjson): client=MagicMock(), task_id="task-id", stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), + metadata_header=_MetadataHeader( + total_size=len(file_content), total_lines=line_count + ), ), file_info=_MetadataFileInfo( offsets=Range(start=0, end=len(file_content) - 1), @@ -55,8 +55,9 @@ def test_with_no_newline_at_end(self, tmp_path, generate_random_ndjson): client=MagicMock(), task_id="task-id", stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), + metadata_header=_MetadataHeader( + total_size=len(file_content), total_lines=line_count + ), ), file_info=_MetadataFileInfo( offsets=Range(start=0, end=len(file_content) - 1), diff --git a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_line.py b/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_line.py index 1dba056fa..37c93647e 100644 --- a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_line.py +++ b/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_line.py @@ -8,7 +8,6 
@@ class TestFileRetrieverByLine: - def test_by_line_from_start(self, generate_random_ndjson, mock_response): line_count = 10 ndjson = generate_random_ndjson(line_count) @@ -19,25 +18,21 @@ def test_by_line_from_start(self, generate_random_ndjson, mock_response): return_value={ "task": { "exportFileFromLine": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, + "offsets": {"start": "0", "end": len(file_content) - 1}, + "lines": {"start": "0", "end": str(line_count - 1)}, "file": "http://some-url.com/file.ndjson", } } - }) + } + ) mock_ctx = _TaskContext( client=mock_client, task_id="task-id", stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), + metadata_header=_MetadataHeader( + total_size=len(file_content), total_lines=line_count + ), ) with patch("requests.get", return_value=mock_response(file_content)): @@ -60,25 +55,21 @@ def test_by_line_from_middle(self, generate_random_ndjson, mock_response): return_value={ "task": { "exportFileFromLine": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, + "offsets": {"start": "0", "end": len(file_content) - 1}, + "lines": {"start": "0", "end": str(line_count - 1)}, "file": "http://some-url.com/file.ndjson", } } - }) + } + ) mock_ctx = _TaskContext( client=mock_client, task_id="task-id", stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), + metadata_header=_MetadataHeader( + total_size=len(file_content), total_lines=line_count + ), ) line_start = 5 @@ -104,25 +95,21 @@ def test_by_line_from_last(self, generate_random_ndjson, mock_response): return_value={ "task": { "exportFileFromLine": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, + "offsets": {"start": "0", "end": len(file_content) - 1}, + "lines": {"start": "0", "end": str(line_count - 1)}, "file": "http://some-url.com/file.ndjson", } } - }) + } + ) mock_ctx = _TaskContext( client=mock_client, task_id="task-id", stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), + metadata_header=_MetadataHeader( + total_size=len(file_content), total_lines=line_count + ), ) line_start = 9 diff --git a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_offset.py b/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_offset.py index 07271d31c..870e03307 100644 --- a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_offset.py +++ b/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_offset.py @@ -8,7 +8,6 @@ class TestFileRetrieverByOffset: - def test_by_offset_from_start(self, generate_random_ndjson, mock_response): line_count = 10 ndjson = generate_random_ndjson(line_count) @@ -19,25 +18,21 @@ def test_by_offset_from_start(self, generate_random_ndjson, mock_response): return_value={ "task": { "exportFileFromOffset": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, + "offsets": {"start": "0", "end": len(file_content) - 1}, + "lines": {"start": "0", "end": str(line_count - 1)}, "file": "http://some-url.com/file.ndjson", } } - }) + } + ) mock_ctx = _TaskContext( client=mock_client, task_id="task-id", stream_type=StreamType.RESULT, 
- metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), + metadata_header=_MetadataHeader( + total_size=len(file_content), total_lines=line_count + ), ) with patch("requests.get", return_value=mock_response(file_content)): @@ -60,25 +55,21 @@ def test_by_offset_from_middle(self, generate_random_ndjson, mock_response): return_value={ "task": { "exportFileFromOffset": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, + "offsets": {"start": "0", "end": len(file_content) - 1}, + "lines": {"start": "0", "end": str(line_count - 1)}, "file": "http://some-url.com/file.ndjson", } } - }) + } + ) mock_ctx = _TaskContext( client=mock_client, task_id="task-id", stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), + metadata_header=_MetadataHeader( + total_size=len(file_content), total_lines=line_count + ), ) line_start = 5 diff --git a/libs/labelbox/tests/unit/export_task/test_unit_json_converter.py b/libs/labelbox/tests/unit/export_task/test_unit_json_converter.py index 249eff0f5..f5ccf26fb 100644 --- a/libs/labelbox/tests/unit/export_task/test_unit_json_converter.py +++ b/libs/labelbox/tests/unit/export_task/test_unit_json_converter.py @@ -1,10 +1,14 @@ from unittest.mock import MagicMock -from labelbox.schema.export_task import Converter, JsonConverter, Range, _MetadataFileInfo +from labelbox.schema.export_task import ( + Converter, + JsonConverter, + Range, + _MetadataFileInfo, +) class TestJsonConverter: - def test_with_correct_ndjson(self, generate_random_ndjson): line_count = 10 ndjson = generate_random_ndjson(line_count) @@ -71,8 +75,9 @@ def test_from_offset(self, generate_random_ndjson): for idx, output in enumerate(converter.convert(input_args)): assert output.current_line == line_start + idx assert output.current_offset == current_offset - assert output.json_str == ndjson[line_start + - idx][skipped_bytes:] + assert ( + output.json_str == ndjson[line_start + idx][skipped_bytes:] + ) current_offset += len(output.json_str) + 1 skipped_bytes = 0 @@ -100,7 +105,8 @@ def test_from_offset_last_line(self, generate_random_ndjson): for idx, output in enumerate(converter.convert(input_args)): assert output.current_line == line_start + idx assert output.current_offset == current_offset - assert output.json_str == ndjson[line_start + - idx][skipped_bytes:] + assert ( + output.json_str == ndjson[line_start + idx][skipped_bytes:] + ) current_offset += len(output.json_str) + 1 skipped_bytes = 0 diff --git a/libs/labelbox/tests/unit/schema/test_user_group.py b/libs/labelbox/tests/unit/schema/test_user_group.py index 0eb0381d6..65584f8ef 100644 --- a/libs/labelbox/tests/unit/schema/test_user_group.py +++ b/libs/labelbox/tests/unit/schema/test_user_group.py @@ -2,7 +2,13 @@ from collections import defaultdict from unittest.mock import MagicMock from labelbox import Client -from labelbox.exceptions import ResourceConflict, ResourceCreationError, ResourceNotFoundError, MalformedQueryException, UnprocessableEntityError +from labelbox.exceptions import ( + ResourceConflict, + ResourceCreationError, + ResourceNotFoundError, + MalformedQueryException, + UnprocessableEntityError, +) from labelbox.schema.project import Project from labelbox.schema.user import User from labelbox.schema.user_group import UserGroup, UserGroupColor @@ -10,6 +16,7 @@ from labelbox.schema.ontology_kind import EditorTaskType from labelbox.schema.media_type 
import MediaType + @pytest.fixture def group_user(): user_values = defaultdict(lambda: None) @@ -30,7 +37,6 @@ def group_project(): class TestUserGroupColor: - def test_user_group_color_values(self): assert UserGroupColor.BLUE.value == "9EC5FF" assert UserGroupColor.PURPLE.value == "CEB8FF" @@ -44,12 +50,11 @@ def test_user_group_color_values(self): class TestUserGroup: - def setup_method(self): self.client = MagicMock(Client) self.client.enable_experimental = True self.group = UserGroup(client=self.client) - + def test_constructor_experimental_needed(self): client = MagicMock(Client) client.enable_experimental = False @@ -74,36 +79,20 @@ def test_update_with_exception_name(self): def test_get(self): projects = [ - { - "id": "project_id_1", - "name": "project_1" - }, - { - "id": "project_id_2", - "name": "project_2" - } + {"id": "project_id_1", "name": "project_1"}, + {"id": "project_id_2", "name": "project_2"}, ] group_members = [ - { - "id": "user_id_1", - "email": "email_1" - }, - { - "id": "user_id_2", - "email": "email_2" - } + {"id": "user_id_1", "email": "email_1"}, + {"id": "user_id_2", "email": "email_2"}, ] self.client.execute.return_value = { "userGroup": { "id": "group_id", "name": "Test Group", "color": "4ED2F9", - "projects": { - "nodes": projects - }, - "members": { - "nodes": group_members - } + "projects": {"nodes": projects}, + "members": {"nodes": group_members}, } } group = UserGroup(self.client) @@ -135,8 +124,8 @@ def test_update(self, group_user, group_project): group.id = "group_id" group.name = "Test Group" group.color = UserGroupColor.BLUE - group.users = { group_user } - group.projects = { group_project } + group.users = {group_user} + group.projects = {group_project} updated_group = group.update() @@ -209,15 +198,11 @@ def test_create(self, group_user, group_project): group = self.group group.name = "New Group" group.color = UserGroupColor.PINK - group.users = { group_user } - group.projects = { group_project } + group.users = {group_user} + group.projects = {group_project} self.client.execute.return_value = { - "createUserGroup": { - "group": { - "id": "group_id" - } - } + "createUserGroup": {"group": {"id": "group_id"}} } created_group = group.create() execute = self.client.execute.call_args[0] @@ -237,7 +222,7 @@ def test_create(self, group_user, group_project): assert list(created_group.users)[0].uid == "user_id" assert len(created_group.projects) == 1 assert list(created_group.projects)[0].uid == "project_id" - + def test_create_resource_creation_error(self): self.client.execute.side_effect = ResourceConflict("Error") group = UserGroup(self.client) @@ -251,9 +236,7 @@ def test_delete(self): group.id = "group_id" self.client.execute.return_value = { - "deleteUserGroup": { - "success": True - } + "deleteUserGroup": {"success": True} } deleted = group.delete() execute = self.client.execute.call_args[0] @@ -287,75 +270,78 @@ def test_user_groups_empty(self): def test_user_groups(self): self.client.execute.return_value = { "userGroups": { - "nextCursor": - None, - "nodes": [{ - "id": "group_id_1", - "name": "Group 1", - "color": "9EC5FF", - "projects": { - "nodes": [{ - "id": "project_id_1", - "name": "Project 1" - }, { - "id": "project_id_2", - "name": "Project 2" - }] + "nextCursor": None, + "nodes": [ + { + "id": "group_id_1", + "name": "Group 1", + "color": "9EC5FF", + "projects": { + "nodes": [ + {"id": "project_id_1", "name": "Project 1"}, + {"id": "project_id_2", "name": "Project 2"}, + ] + }, + "members": { + "nodes": [ + { + "id": "user_id_1", + 
"email": "user1@example.com", + }, + { + "id": "user_id_2", + "email": "user2@example.com", + }, + ] + }, }, - "members": { - "nodes": [{ - "id": "user_id_1", - "email": "user1@example.com" - }, { - "id": "user_id_2", - "email": "user2@example.com" - }] - } - }, { - "id": "group_id_2", - "name": "Group 2", - "color": "9EC5FF", - "projects": { - "nodes": [{ - "id": "project_id_3", - "name": "Project 3" - }, { - "id": "project_id_4", - "name": "Project 4" - }] + { + "id": "group_id_2", + "name": "Group 2", + "color": "9EC5FF", + "projects": { + "nodes": [ + {"id": "project_id_3", "name": "Project 3"}, + {"id": "project_id_4", "name": "Project 4"}, + ] + }, + "members": { + "nodes": [ + { + "id": "user_id_3", + "email": "user3@example.com", + }, + { + "id": "user_id_4", + "email": "user4@example.com", + }, + ] + }, }, - "members": { - "nodes": [{ - "id": "user_id_3", - "email": "user3@example.com" - }, { - "id": "user_id_4", - "email": "user4@example.com" - }] - } - }, { - "id": "group_id_3", - "name": "Group 3", - "color": "9EC5FF", - "projects": { - "nodes": [{ - "id": "project_id_5", - "name": "Project 5" - }, { - "id": "project_id_6", - "name": "Project 6" - }] + { + "id": "group_id_3", + "name": "Group 3", + "color": "9EC5FF", + "projects": { + "nodes": [ + {"id": "project_id_5", "name": "Project 5"}, + {"id": "project_id_6", "name": "Project 6"}, + ] + }, + "members": { + "nodes": [ + { + "id": "user_id_5", + "email": "user5@example.com", + }, + { + "id": "user_id_6", + "email": "user6@example.com", + }, + ] + }, }, - "members": { - "nodes": [{ - "id": "user_id_5", - "email": "user5@example.com" - }, { - "id": "user_id_6", - "email": "user6@example.com" - }] - } - }] + ], } } @@ -389,4 +375,5 @@ def test_user_groups(self): if __name__ == "__main__": import subprocess + subprocess.call(["pytest", "-v", __file__]) diff --git a/libs/labelbox/tests/unit/test_annotation_import.py b/libs/labelbox/tests/unit/test_annotation_import.py index ff0835467..d4642f17b 100644 --- a/libs/labelbox/tests/unit/test_annotation_import.py +++ b/libs/labelbox/tests/unit/test_annotation_import.py @@ -10,69 +10,59 @@ def test_data_row_validation_errors(): "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", }, "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, + "dataRow": {"globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"}, }, { "answer": { "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", }, "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, + "dataRow": {"globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"}, }, { "answer": { "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", }, "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, + "dataRow": {"globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"}, }, { "answer": { "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", }, "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, + "dataRow": {"globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"}, }, { "answer": { "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", }, "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, + "dataRow": {"globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"}, }, ] # Set up data for validation errors # Invalid: Remove 'dataRow' part entirely - del predictions[0]['dataRow'] + del predictions[0]["dataRow"] # Invalid: Set both id and globalKey - predictions[1]['dataRow'] = { - 'id': 'some id', - 'globalKey': 'some global key' + 
predictions[1]["dataRow"] = { + "id": "some id", + "globalKey": "some global key", } # Invalid: Set both id and globalKey to None - predictions[2]['dataRow'] = {'id': None, 'globalKey': None} + predictions[2]["dataRow"] = {"id": None, "globalKey": None} # Valid - predictions[3]['dataRow'] = { - 'id': 'some id', + predictions[3]["dataRow"] = { + "id": "some id", } # Valid - predictions[4]['dataRow'] = { - 'globalKey': 'some global key', + predictions[4]["dataRow"] = { + "globalKey": "some global key", } with pytest.raises(ValueError) as exc_info: @@ -80,6 +70,12 @@ def test_data_row_validation_errors(): exception_str = str(exc_info.value) assert "Found 3 annotations with errors" in exception_str assert "'dataRow' is missing in" in exception_str - assert "Must provide only one of 'id' or 'globalKey' for 'dataRow'" in exception_str - assert "'dataRow': {'id': 'some id', 'globalKey': 'some global key'}" in exception_str + assert ( + "Must provide only one of 'id' or 'globalKey' for 'dataRow'" + in exception_str + ) + assert ( + "'dataRow': {'id': 'some id', 'globalKey': 'some global key'}" + in exception_str + ) assert "'dataRow': {'id': None, 'globalKey': None}" in exception_str diff --git a/libs/labelbox/tests/unit/test_data_row_upsert_data.py b/libs/labelbox/tests/unit/test_data_row_upsert_data.py index b8c68c0af..11cc4153f 100644 --- a/libs/labelbox/tests/unit/test_data_row_upsert_data.py +++ b/libs/labelbox/tests/unit/test_data_row_upsert_data.py @@ -1,32 +1,37 @@ from unittest.mock import MagicMock, patch import pytest -from labelbox.schema.internal.data_row_upsert_item import (DataRowUpsertItem, - DataRowCreateItem) +from labelbox.schema.internal.data_row_upsert_item import ( + DataRowUpsertItem, + DataRowCreateItem, +) from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.schema.asset_attachment import AttachmentType from labelbox.schema.dataset import Dataset -from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator +from labelbox.schema.internal.descriptor_file_creator import ( + DescriptorFileCreator, +) from labelbox.schema.data_row import DataRow @pytest.fixture def data_row_create_items(): - dataset_id = 'test_dataset' + dataset_id = "test_dataset" items = [ { "row_data": "http://my_site.com/photos/img_01.jpg", "global_key": "global_key1", "external_id": "ex_id1", - "attachments": [{ - "type": AttachmentType.RAW_TEXT, - "name": "att1", - "value": "test1" - }], - "metadata": [{ - "name": "tag", - "value": "tag value" - },] + "attachments": [ + { + "type": AttachmentType.RAW_TEXT, + "name": "att1", + "value": "test1", + } + ], + "metadata": [ + {"name": "tag", "value": "tag value"}, + ], }, ] return dataset_id, items @@ -34,7 +39,7 @@ def data_row_create_items(): @pytest.fixture def data_row_create_items_row_data_none(): - dataset_id = 'test_dataset' + dataset_id = "test_dataset" items = [ { "row_data": None, @@ -45,16 +50,10 @@ def data_row_create_items_row_data_none(): @pytest.fixture def data_row_update_items(): - dataset_id = 'test_dataset' + dataset_id = "test_dataset" items = [ - { - "key": GlobalKey("global_key1"), - "global_key": "global_key1_updated" - }, - { - "key": UniqueId('unique_id1'), - "external_id": "ex_id1_updated" - }, + {"key": GlobalKey("global_key1"), "global_key": "global_key1_updated"}, + {"key": UniqueId("unique_id1"), "external_id": "ex_id1_updated"}, ] return dataset_id, items @@ -84,22 +83,20 @@ def test_data_row_create_items_not_updateable(data_row_update_items): def test_upsert_is_empty(): - item = 
DataRowUpsertItem(id={ - "id": UniqueId, - "value": UniqueId("123") - }, - payload={}) + item = DataRowUpsertItem( + id={"id": UniqueId, "value": UniqueId("123")}, payload={} + ) assert item.is_empty() - item = DataRowUpsertItem(id={ - "id": UniqueId, - "value": UniqueId("123") - }, - payload={"dataset_id": "test_dataset"}) + item = DataRowUpsertItem( + id={"id": UniqueId, "value": UniqueId("123")}, + payload={"dataset_id": "test_dataset"}, + ) assert item.is_empty() item = DataRowUpsertItem( - id={}, payload={"row_data": "http://my_site.com/photos/img_01.jpg"}) + id={}, payload={"row_data": "http://my_site.com/photos/img_01.jpg"} + ) assert not item.is_empty() @@ -117,29 +114,26 @@ def test_create_is_empty(): assert item.is_empty() item = DataRowCreateItem( - id={}, payload={"row_data": "http://my_site.com/photos/img_01.jpg"}) + id={}, payload={"row_data": "http://my_site.com/photos/img_01.jpg"} + ) assert not item.is_empty() item = DataRowCreateItem( id={}, - payload={DataRow.row_data: "http://my_site.com/photos/img_01.jpg"}) + payload={DataRow.row_data: "http://my_site.com/photos/img_01.jpg"}, + ) assert not item.is_empty() legacy_converstational_data_payload = { - "externalId": - "Convo-123", - "type": - "application/vnd.labelbox.conversational", - "conversationalData": [{ - "messageId": - "message-0", - "content": - "I love iphone! i just bought new iphone! :smiling_face_with_3_hearts: :calling:", - "user": { - "userId": "Bot 002", - "name": "Bot" - }, - }] + "externalId": "Convo-123", + "type": "application/vnd.labelbox.conversational", + "conversationalData": [ + { + "messageId": "message-0", + "content": "I love iphone! i just bought new iphone! :smiling_face_with_3_hearts: :calling:", + "user": {"userId": "Bot 002", "name": "Bot"}, + } + ], } item = DataRowCreateItem(id={}, payload=legacy_converstational_data_payload) assert not item.is_empty() @@ -154,20 +148,25 @@ def test_create_row_data_none(): ] client = MagicMock() dataset = Dataset( - client, { - "id": 'test_dataset', - "name": 'test_dataset', + client, + { + "id": "test_dataset", + "name": "test_dataset", "createdAt": "2021-06-01T00:00:00.000Z", "description": "test_dataset", "updatedAt": "2021-06-01T00:00:00.000Z", "rowCount": 0, - }) - - with patch.object(DescriptorFileCreator, - 'create', - return_value=["http://bar.com/chunk_uri"]): - with pytest.raises(ValueError, - match="Some items have an empty payload"): + }, + ) + + with patch.object( + DescriptorFileCreator, + "create", + return_value=["http://bar.com/chunk_uri"], + ): + with pytest.raises( + ValueError, match="Some items have an empty payload" + ): dataset.create_data_rows(items) client.execute.assert_not_called() diff --git a/libs/labelbox/tests/unit/test_exceptions.py b/libs/labelbox/tests/unit/test_exceptions.py index 69bcfbd77..4602fb984 100644 --- a/libs/labelbox/tests/unit/test_exceptions.py +++ b/libs/labelbox/tests/unit/test_exceptions.py @@ -3,11 +3,18 @@ from labelbox.exceptions import error_message_for_unparsed_graphql_error -@pytest.mark.parametrize('exception_message, expected_result', [ - ("Unparsed errors on query execution: [{'message': 'Cannot create model config for project because model setup is complete'}]", - "Cannot create model config for project because model setup is complete"), - ("blah blah blah", "Unknown error"), -]) +@pytest.mark.parametrize( + "exception_message, expected_result", + [ + ( + "Unparsed errors on query execution: [{'message': 'Cannot create model config for project because model setup is complete'}]", + "Cannot 
create model config for project because model setup is complete", + ), + ("blah blah blah", "Unknown error"), + ], +) def test_client_unparsed_exception_messages(exception_message, expected_result): - assert error_message_for_unparsed_graphql_error( - exception_message) == expected_result + assert ( + error_message_for_unparsed_graphql_error(exception_message) + == expected_result + ) diff --git a/libs/labelbox/tests/unit/test_label_data_type.py b/libs/labelbox/tests/unit/test_label_data_type.py index 737136a36..7bc32e37c 100644 --- a/libs/labelbox/tests/unit/test_label_data_type.py +++ b/libs/labelbox/tests/unit/test_label_data_type.py @@ -2,35 +2,36 @@ import pytest from pydantic import ValidationError -from labelbox.data.annotation_types.data.generic_data_row_data import GenericDataRowData +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) from labelbox.data.annotation_types.data.video import VideoData from labelbox.data.annotation_types.label import Label def test_generic_data_type(): data = { - 'global_key': - 'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr', + "global_key": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr", } label = Label(data=data) data = label.data assert isinstance(data, GenericDataRowData) - assert data.global_key == 'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr' + assert ( + data.global_key + == "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr" + ) def test_generic_data_type_validations(): data = { - 'row_data': - 'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr', + "row_data": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr", } with pytest.raises(ValueError, match="Exactly one of"): Label(data=data) data = { - 'uid': - "abcd", - 'global_key': - 'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr', + "uid": "abcd", + "global_key": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr", } with pytest.raises(ValueError, match="Only one of"): Label(data=data) @@ -38,22 +39,26 @@ def test_generic_data_type_validations(): def test_video_data_type(): data = { - 'global_key': - 'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr', + "global_key": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr", } with pytest.warns(UserWarning, match="Use a dict"): label = Label(data=VideoData(**data)) data = label.data assert isinstance(data, VideoData) - assert data.global_key == 'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr' + assert ( + data.global_key + == "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr" + ) def test_generic_data_row(): data = { - 'global_key': - 'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr', + "global_key": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr", } label = Label(data=GenericDataRowData(**data)) data = label.data assert isinstance(data, GenericDataRowData) - assert data.global_key == 
'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr' + assert ( + data.global_key + == "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr" + ) diff --git a/libs/labelbox/tests/unit/test_mal_import.py b/libs/labelbox/tests/unit/test_mal_import.py index 799944a13..3dc3eea56 100644 --- a/libs/labelbox/tests/unit/test_mal_import.py +++ b/libs/labelbox/tests/unit/test_mal_import.py @@ -11,35 +11,32 @@ def test_should_warn_user_about_unsupported_confidence(): labels = [ { - "bbox": { - "height": 428, - "left": 2089, - "top": 1251, - "width": 158 - }, - "classifications": [{ - "answer": [{ - "schemaId": "ckrb1sfl8099e0y919v260awv", - "confidence": 0.894 - }], - "schemaId": "ckrb1sfkn099c0y910wbo0p1a" - }], - "dataRow": { - "id": "ckrb1sf1i1g7i0ybcdc6oc8ct" - }, + "bbox": {"height": 428, "left": 2089, "top": 1251, "width": 158}, + "classifications": [ + { + "answer": [ + { + "schemaId": "ckrb1sfl8099e0y919v260awv", + "confidence": 0.894, + } + ], + "schemaId": "ckrb1sfkn099c0y910wbo0p1a", + } + ], + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, "schemaId": "ckrb1sfjx099a0y914hl319ie", - "uuid": "d009925d-91a3-4f67-abd9-753453f5a584" + "uuid": "d009925d-91a3-4f67-abd9-753453f5a584", }, ] - with patch.object(MALPredictionImport, '_create_mal_import_from_bytes'): - with patch.object(logger, 'warning') as warning_mock: - MALPredictionImport.create_from_objects(client=MagicMock(), - project_id=id, - name=id, - predictions=labels) + with patch.object(MALPredictionImport, "_create_mal_import_from_bytes"): + with patch.object(logger, "warning") as warning_mock: + MALPredictionImport.create_from_objects( + client=MagicMock(), project_id=id, name=id, predictions=labels + ) warning_mock.assert_called_once() "Confidence scores are not supported in MAL Prediction Import" in warning_mock.call_args_list[ - 0].args[0] + 0 + ].args[0] def test_invalid_labels_format(): @@ -47,29 +44,25 @@ def test_invalid_labels_format(): id = str(uuid.uuid4()) label = { - "bbox": { - "height": 428, - "left": 2089, - "top": 1251, - "width": 158 - }, - "classifications": [{ - "answer": [{ - "schemaId": "ckrb1sfl8099e0y919v260awv", - "confidence": 0.894 - }], - "schemaId": "ckrb1sfkn099c0y910wbo0p1a" - }], - "dataRow": { - "id": "ckrb1sf1i1g7i0ybcdc6oc8ct" - }, + "bbox": {"height": 428, "left": 2089, "top": 1251, "width": 158}, + "classifications": [ + { + "answer": [ + { + "schemaId": "ckrb1sfl8099e0y919v260awv", + "confidence": 0.894, + } + ], + "schemaId": "ckrb1sfkn099c0y910wbo0p1a", + } + ], + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, "schemaId": "ckrb1sfjx099a0y914hl319ie", - "uuid": "3a83db52-75e0-49af-a171-234ce604502a" + "uuid": "3a83db52-75e0-49af-a171-234ce604502a", } - with patch.object(MALPredictionImport, '_create_mal_import_from_bytes'): + with patch.object(MALPredictionImport, "_create_mal_import_from_bytes"): with pytest.raises(TypeError): - MALPredictionImport.create_from_objects(client=MagicMock(), - project_id=id, - name=id, - predictions=label) + MALPredictionImport.create_from_objects( + client=MagicMock(), project_id=id, name=id, predictions=label + ) diff --git a/libs/labelbox/tests/unit/test_ndjson_parsing.py b/libs/labelbox/tests/unit/test_ndjson_parsing.py index 832e41928..508e44d74 100644 --- a/libs/labelbox/tests/unit/test_ndjson_parsing.py +++ b/libs/labelbox/tests/unit/test_ndjson_parsing.py @@ -15,7 +15,7 @@ def test_loads(ndjson_content): def test_loads_bytes(ndjson_content): expected_line, 
expected_objects = ndjson_content - bytes_line = expected_line.encode('utf-8') + bytes_line = expected_line.encode("utf-8") parsed_line = parser.loads(bytes_line) assert parsed_line == expected_objects diff --git a/libs/labelbox/tests/unit/test_project.py b/libs/labelbox/tests/unit/test_project.py index 367f74296..5e5f99c57 100644 --- a/libs/labelbox/tests/unit/test_project.py +++ b/libs/labelbox/tests/unit/test_project.py @@ -32,15 +32,21 @@ def project_entity(): @pytest.mark.parametrize( - 'api_editor_task_type, expected_editor_task_type', - [(None, EditorTaskType.Missing), - ('MODEL_CHAT_EVALUATION', EditorTaskType.ModelChatEvaluation), - ('RESPONSE_CREATION', EditorTaskType.ResponseCreation), - ('OFFLINE_MODEL_CHAT_EVALUATION', - EditorTaskType.OfflineModelChatEvaluation), - ('NEW_TYPE', EditorTaskType.Missing)]) -def test_project_editor_task_type(api_editor_task_type, - expected_editor_task_type, project_entity): + "api_editor_task_type, expected_editor_task_type", + [ + (None, EditorTaskType.Missing), + ("MODEL_CHAT_EVALUATION", EditorTaskType.ModelChatEvaluation), + ("RESPONSE_CREATION", EditorTaskType.ResponseCreation), + ( + "OFFLINE_MODEL_CHAT_EVALUATION", + EditorTaskType.OfflineModelChatEvaluation, + ), + ("NEW_TYPE", EditorTaskType.Missing), + ], +) +def test_project_editor_task_type( + api_editor_task_type, expected_editor_task_type, project_entity +): client = MagicMock() project = Project( client, diff --git a/libs/labelbox/tests/unit/test_unit_delete_batch_data_row_metadata.py b/libs/labelbox/tests/unit/test_unit_delete_batch_data_row_metadata.py index 561f8d6b0..cd6eadd79 100644 --- a/libs/labelbox/tests/unit/test_unit_delete_batch_data_row_metadata.py +++ b/libs/labelbox/tests/unit/test_unit_delete_batch_data_row_metadata.py @@ -7,50 +7,44 @@ def test_dict_delete_data_row_batch(): obj = _DeleteBatchDataRowMetadata( data_row_identifier=UniqueId("abcd"), - schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"]) + schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"], + ) assert obj.model_dump() == { - "data_row_identifier": { - "id": "abcd", - "id_type": "ID" - }, + "data_row_identifier": {"id": "abcd", "id_type": "ID"}, "schema_ids": [ - "clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy" - ] + "clqh77tyk000008l2a9mjesa1", + "clqh784br000008jy0yuq04fy", + ], } obj = _DeleteBatchDataRowMetadata( data_row_identifier=GlobalKey("fegh"), - schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"]) + schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"], + ) assert obj.model_dump() == { - "data_row_identifier": { - "id": "fegh", - "id_type": "GKEY" - }, + "data_row_identifier": {"id": "fegh", "id_type": "GKEY"}, "schema_ids": [ - "clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy" - ] + "clqh77tyk000008l2a9mjesa1", + "clqh784br000008jy0yuq04fy", + ], } def test_dict_delete_data_row_batch_by_alias(): obj = _DeleteBatchDataRowMetadata( data_row_identifier=UniqueId("abcd"), - schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"]) + schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"], + ) assert obj.model_dump(by_alias=True) == { - "dataRowIdentifier": { - "id": "abcd", - "idType": "ID" - }, - "schemaIds": ["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"] + "dataRowIdentifier": {"id": "abcd", "idType": "ID"}, + "schemaIds": ["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"], } obj = _DeleteBatchDataRowMetadata( data_row_identifier=GlobalKey("fegh"), - 
schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"]) + schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"], + ) assert obj.model_dump(by_alias=True) == { - "dataRowIdentifier": { - "id": "fegh", - "idType": "GKEY" - }, - "schemaIds": ["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"] + "dataRowIdentifier": {"id": "fegh", "idType": "GKEY"}, + "schemaIds": ["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"], } diff --git a/libs/labelbox/tests/unit/test_unit_descriptor_file_creator.py b/libs/labelbox/tests/unit/test_unit_descriptor_file_creator.py index 630d80573..621317ddd 100644 --- a/libs/labelbox/tests/unit/test_unit_descriptor_file_creator.py +++ b/libs/labelbox/tests/unit/test_unit_descriptor_file_creator.py @@ -3,7 +3,9 @@ from unittest.mock import MagicMock, Mock import pytest -from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator +from labelbox.schema.internal.descriptor_file_creator import ( + DescriptorFileCreator, +) def test_chunk_down_by_bytes_row_too_large(): @@ -14,8 +16,9 @@ def test_chunk_down_by_bytes_row_too_large(): chunk = [{"row_data": "a"}] max_chunk_size_bytes = 1 - res = descriptor_file_creator._chunk_down_by_bytes(chunk, - max_chunk_size_bytes) + res = descriptor_file_creator._chunk_down_by_bytes( + chunk, max_chunk_size_bytes + ) assert [x for x in res] == [json.dumps([{"row_data": "a"}])] @@ -27,14 +30,12 @@ def test_chunk_down_by_bytes_more_chunks(): chunk = [{"row_data": "a"}, {"row_data": "b"}] max_chunk_size_bytes = len(json.dumps(chunk).encode("utf-8")) - 1 - res = descriptor_file_creator._chunk_down_by_bytes(chunk, - max_chunk_size_bytes) + res = descriptor_file_creator._chunk_down_by_bytes( + chunk, max_chunk_size_bytes + ) assert [x for x in res] == [ - json.dumps([{ - "row_data": "a" - }]), json.dumps([{ - "row_data": "b" - }]) + json.dumps([{"row_data": "a"}]), + json.dumps([{"row_data": "b"}]), ] @@ -46,11 +47,9 @@ def test_chunk_down_by_bytes_one_chunk(): chunk = [{"row_data": "a"}, {"row_data": "b"}] max_chunk_size_bytes = len(json.dumps(chunk).encode("utf-8")) - res = descriptor_file_creator._chunk_down_by_bytes(chunk, - max_chunk_size_bytes) - assert [x for x in res - ] == [json.dumps([{ - "row_data": "a" - }, { - "row_data": "b" - }])] + res = descriptor_file_creator._chunk_down_by_bytes( + chunk, max_chunk_size_bytes + ) + assert [x for x in res] == [ + json.dumps([{"row_data": "a"}, {"row_data": "b"}]) + ] diff --git a/libs/labelbox/tests/unit/test_unit_entity_meta.py b/libs/labelbox/tests/unit/test_unit_entity_meta.py index d24f985d9..06278951b 100644 --- a/libs/labelbox/tests/unit/test_unit_entity_meta.py +++ b/libs/labelbox/tests/unit/test_unit_entity_meta.py @@ -5,7 +5,6 @@ def test_illegal_cache_cond1(): - class TestEntityA(DbObject): test_entity_b = Relationship.ToOne("TestEntityB", cache=True) @@ -14,12 +13,13 @@ class TestEntityA(DbObject): class TestEntityB(DbObject): another_entity = Relationship.ToOne("AnotherEntity", cache=True) - assert "`test_entity_a` caches `test_entity_b` which caches `['another_entity']`" in str( - exc_info.value) + assert ( + "`test_entity_a` caches `test_entity_b` which caches `['another_entity']`" + in str(exc_info.value) + ) def test_illegal_cache_cond2(): - class TestEntityD(DbObject): another_entity = Relationship.ToOne("AnotherEntity", cache=True) @@ -28,5 +28,7 @@ class TestEntityD(DbObject): class TestEntityC(DbObject): test_entity_d = Relationship.ToOne("TestEntityD", cache=True) - assert "`test_entity_c` caches 
`test_entity_d` which caches `['another_entity']`" in str( - exc_info.value) + assert ( + "`test_entity_c` caches `test_entity_d` which caches `['another_entity']`" + in str(exc_info.value) + ) diff --git a/libs/labelbox/tests/unit/test_unit_export_filters.py b/libs/labelbox/tests/unit/test_unit_export_filters.py index 5986ae44e..3be78152e 100644 --- a/libs/labelbox/tests/unit/test_unit_export_filters.py +++ b/libs/labelbox/tests/unit/test_unit_export_filters.py @@ -8,33 +8,39 @@ def test_ids_filter(): client = MagicMock() filters = {"data_row_ids": ["id1", "id2"], "batch_ids": ["b1", "b2"]} - assert build_filters(client, filters) == [{ - "ids": ["id1", "id2"], - "operator": "is", - "type": "data_row_id", - }, { - "ids": ["b1", "b2"], - "operator": "is", - "type": "batch", - }] + assert build_filters(client, filters) == [ + { + "ids": ["id1", "id2"], + "operator": "is", + "type": "data_row_id", + }, + { + "ids": ["b1", "b2"], + "operator": "is", + "type": "batch", + }, + ] def test_ids_empty_filter(): client = MagicMock() filters = {"data_row_ids": [], "batch_ids": ["b1", "b2"]} - with pytest.raises(ValueError, - match="data_row_id filter expects a non-empty list."): + with pytest.raises( + ValueError, match="data_row_id filter expects a non-empty list." + ): build_filters(client, filters) def test_global_keys_filter(): client = MagicMock() filters = {"global_keys": ["id1", "id2"]} - assert build_filters(client, filters) == [{ - "ids": ["id1", "id2"], - "operator": "is", - "type": "global_key", - }] + assert build_filters(client, filters) == [ + { + "ids": ["id1", "id2"], + "operator": "is", + "type": "global_key", + } + ] def test_validations(): @@ -44,8 +50,7 @@ def test_validations(): "data_row_ids": ["id1", "id2"], } with pytest.raises( - ValueError, - match= - "data_rows and global_keys cannot both be present in export filters" + ValueError, + match="data_rows and global_keys cannot both be present in export filters", ): build_filters(client, filters) diff --git a/libs/labelbox/tests/unit/test_unit_label_import.py b/libs/labelbox/tests/unit/test_unit_label_import.py index feff4694c..b386a664d 100644 --- a/libs/labelbox/tests/unit/test_unit_label_import.py +++ b/libs/labelbox/tests/unit/test_unit_label_import.py @@ -13,27 +13,20 @@ def test_should_warn_user_about_unsupported_confidence(): { "uuid": "b862c586-8614-483c-b5e6-82810f70cac0", "schemaId": "ckrazcueb16og0z6609jj7y3y", - "dataRow": { - "id": "ckrazctum0z8a0ybc0b0o0g0v" - }, + "dataRow": {"id": "ckrazctum0z8a0ybc0b0o0g0v"}, "confidence": 0.851, - "bbox": { - "top": 1352, - "left": 2275, - "height": 350, - "width": 139 - } + "bbox": {"top": 1352, "left": 2275, "height": 350, "width": 139}, }, ] - with patch.object(LabelImport, '_create_label_import_from_bytes'): - with patch.object(logger, 'warning') as warning_mock: - LabelImport.create_from_objects(client=MagicMock(), - project_id=id, - name=id, - labels=labels) + with patch.object(LabelImport, "_create_label_import_from_bytes"): + with patch.object(logger, "warning") as warning_mock: + LabelImport.create_from_objects( + client=MagicMock(), project_id=id, name=id, labels=labels + ) warning_mock.assert_called_once() "Confidence scores are not supported in Label Import" in warning_mock.call_args_list[ - 0].args[0] + 0 + ].args[0] def test_invalid_labels_format(): @@ -43,19 +36,11 @@ def test_invalid_labels_format(): label = { "uuid": "b862c586-8614-483c-b5e6-82810f70cac0", "schemaId": "ckrazcueb16og0z6609jj7y3y", - "dataRow": { - "id": "ckrazctum0z8a0ybc0b0o0g0v" - }, - "bbox": 
{ - "top": 1352, - "left": 2275, - "height": 350, - "width": 139 - } + "dataRow": {"id": "ckrazctum0z8a0ybc0b0o0g0v"}, + "bbox": {"top": 1352, "left": 2275, "height": 350, "width": 139}, } - with patch.object(LabelImport, '_create_label_import_from_bytes'): + with patch.object(LabelImport, "_create_label_import_from_bytes"): with pytest.raises(TypeError): - LabelImport.create_from_objects(client=MagicMock(), - project_id=id, - name=id, - labels=label) + LabelImport.create_from_objects( + client=MagicMock(), project_id=id, name=id, labels=label + ) diff --git a/libs/labelbox/tests/unit/test_unit_ontology.py b/libs/labelbox/tests/unit/test_unit_ontology.py index ac53827c6..0566ad623 100644 --- a/libs/labelbox/tests/unit/test_unit_ontology.py +++ b/libs/labelbox/tests/unit/test_unit_ontology.py @@ -5,183 +5,187 @@ from itertools import product _SAMPLE_ONTOLOGY = { - "tools": [{ - "schemaNodeId": None, - "featureSchemaId": None, - "required": False, - "name": "poly", - "color": "#FF0000", - "tool": "polygon", - "classifications": [] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "required": False, - "name": "segment", - "color": "#FF0000", - "tool": "superpixel", - "classifications": [] - }, { - "schemaNodeId": - None, - "featureSchemaId": - None, - "required": - False, - "name": - "bbox", - "color": - "#FF0000", - "tool": - "rectangle", - "classifications": [{ - "schemaNodeId": - None, - "featureSchemaId": - None, - "required": - True, - "instructions": - "nested classification", - "name": - "nested classification", - "type": - "radio", - 'uiMode': - "searchable", - "options": [{ - "schemaNodeId": - None, - "featureSchemaId": - None, - "label": - "first", - "value": - "first", - "options": [{ + "tools": [ + { + "schemaNodeId": None, + "featureSchemaId": None, + "required": False, + "name": "poly", + "color": "#FF0000", + "tool": "polygon", + "classifications": [], + }, + { + "schemaNodeId": None, + "featureSchemaId": None, + "required": False, + "name": "segment", + "color": "#FF0000", + "tool": "superpixel", + "classifications": [], + }, + { + "schemaNodeId": None, + "featureSchemaId": None, + "required": False, + "name": "bbox", + "color": "#FF0000", + "tool": "rectangle", + "classifications": [ + { + "schemaNodeId": None, + "featureSchemaId": None, + "required": True, + "instructions": "nested classification", + "name": "nested classification", + "type": "radio", + "uiMode": "searchable", + "options": [ + { + "schemaNodeId": None, + "featureSchemaId": None, + "label": "first", + "value": "first", + "options": [ + { + "schemaNodeId": None, + "featureSchemaId": None, + "required": False, + "instructions": "nested nested text", + "name": "nested nested text", + "type": "text", + "options": [], + } + ], + }, + { + "schemaNodeId": None, + "featureSchemaId": None, + "label": "second", + "value": "second", + "options": [], + }, + ], + }, + { "schemaNodeId": None, "featureSchemaId": None, - "required": False, - "instructions": "nested nested text", - "name": "nested nested text", + "required": True, + "instructions": "nested text", + "name": "nested text", "type": "text", - "options": [] - }] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "label": "second", - "value": "second", - "options": [] - }] - }, { + "options": [], + }, + ], + }, + { "schemaNodeId": None, "featureSchemaId": None, - "required": True, - "instructions": "nested text", - "name": "nested text", - "type": "text", - "options": [] - }] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "required": 
False, - "name": "dot", - "color": "#FF0000", - "tool": "point", - "classifications": [] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "required": False, - "name": "polyline", - "color": "#FF0000", - "tool": "line", - "classifications": [] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "required": False, - "name": "ner", - "color": "#FF0000", - "tool": "named-entity", - "classifications": [] - }], - "classifications": [{ - "schemaNodeId": - None, - "featureSchemaId": - None, - "required": - True, - "instructions": - "This is a question.", - "name": - "This is a question.", - "type": - "radio", - "scope": - "global", - 'uiMode': - "searchable", - "options": [{ + "required": False, + "name": "dot", + "color": "#FF0000", + "tool": "point", + "classifications": [], + }, + { "schemaNodeId": None, "featureSchemaId": None, - "label": "yes", - "value": "definitely yes", - "options": [] - }, { + "required": False, + "name": "polyline", + "color": "#FF0000", + "tool": "line", + "classifications": [], + }, + { "schemaNodeId": None, "featureSchemaId": None, - "label": "no", - "value": "definitely not", - "options": [] - }] - }] + "required": False, + "name": "ner", + "color": "#FF0000", + "tool": "named-entity", + "classifications": [], + }, + ], + "classifications": [ + { + "schemaNodeId": None, + "featureSchemaId": None, + "required": True, + "instructions": "This is a question.", + "name": "This is a question.", + "type": "radio", + "scope": "global", + "uiMode": "searchable", + "options": [ + { + "schemaNodeId": None, + "featureSchemaId": None, + "label": "yes", + "value": "definitely yes", + "options": [], + }, + { + "schemaNodeId": None, + "featureSchemaId": None, + "label": "no", + "value": "definitely not", + "options": [], + }, + ], + } + ], } @pytest.mark.parametrize("tool_type", list(Tool.Type)) def test_create_tool(tool_type) -> None: t = Tool(tool=tool_type, name="tool") - assert (t.tool == tool_type) + assert t.tool == tool_type @pytest.mark.parametrize("class_type", list(Classification.Type)) def test_create_classification(class_type) -> None: c = Classification(class_type=class_type, name="classification") - assert (c.class_type == class_type) + assert c.class_type == class_type + -@pytest.mark.parametrize("ui_mode_type, class_type", list(product(list(Classification.UIMode), list(Classification.Type)))) +@pytest.mark.parametrize( + "ui_mode_type, class_type", + list(product(list(Classification.UIMode), list(Classification.Type))), +) def test_create_classification_with_ui_mode(ui_mode_type, class_type) -> None: - c = Classification(name="classification", class_type=class_type, ui_mode=ui_mode_type) - assert (c.ui_mode == ui_mode_type) + c = Classification( + name="classification", class_type=class_type, ui_mode=ui_mode_type + ) + assert c.ui_mode == ui_mode_type -@pytest.mark.parametrize("value, expected_value, typing", - [(3, 3, int), ("string", "string", str)]) +@pytest.mark.parametrize( + "value, expected_value, typing", [(3, 3, int), ("string", "string", str)] +) def test_create_option_with_value(value, expected_value, typing) -> None: o = Option(value=value) - assert (o.value == expected_value) - assert (o.value == o.label) + assert o.value == expected_value + assert o.value == o.label -@pytest.mark.parametrize("value, label, expected_value, typing", - [(3, 2, 3, int), - ("string", "another string", "string", str)]) -def test_create_option_with_value_and_label(value, label, expected_value, - typing) -> None: +@pytest.mark.parametrize( + "value, label, 
expected_value, typing", + [(3, 2, 3, int), ("string", "another string", "string", str)], +) +def test_create_option_with_value_and_label( + value, label, expected_value, typing +) -> None: o = Option(value=value, label=label) - assert (o.value == expected_value) + assert o.value == expected_value assert o.value != o.label assert isinstance(o.value, typing) def test_create_empty_ontology() -> None: o = OntologyBuilder() - assert (o.tools == []) - assert (o.classifications == []) + assert o.tools == [] + assert o.classifications == [] def test_add_ontology_tool() -> None: @@ -193,7 +197,7 @@ def test_add_ontology_tool() -> None: assert len(o.tools) == 2 for tool in o.tools: - assert (type(tool) == Tool) + assert type(tool) == Tool with pytest.raises(InconsistentOntologyException) as exc: o.add_tool(Tool(tool=Tool.Type.BBOX, name="bounding box")) @@ -203,19 +207,22 @@ def test_add_ontology_tool() -> None: def test_add_ontology_classification() -> None: o = OntologyBuilder() o.add_classification( - Classification(class_type=Classification.Type.TEXT, name="text")) + Classification(class_type=Classification.Type.TEXT, name="text") + ) second_classification = Classification( - class_type=Classification.Type.CHECKLIST, name="checklist") + class_type=Classification.Type.CHECKLIST, name="checklist" + ) o.add_classification(second_classification) assert len(o.classifications) == 2 for classification in o.classifications: - assert (type(classification) == Classification) + assert type(classification) == Classification with pytest.raises(InconsistentOntologyException) as exc: o.add_classification( - Classification(class_type=Classification.Type.TEXT, name="text")) + Classification(class_type=Classification.Type.TEXT, name="text") + ) assert "Duplicate classification name" in str(exc.value) @@ -253,8 +260,9 @@ def test_option_add_option() -> None: def test_ontology_asdict() -> None: - assert OntologyBuilder.from_dict( - _SAMPLE_ONTOLOGY).asdict() == _SAMPLE_ONTOLOGY + assert ( + OntologyBuilder.from_dict(_SAMPLE_ONTOLOGY).asdict() == _SAMPLE_ONTOLOGY + ) def test_classification_using_instructions_instead_of_name_shows_warning(): diff --git a/libs/labelbox/tests/unit/test_unit_ontology_kind.py b/libs/labelbox/tests/unit/test_unit_ontology_kind.py index 51e2cf214..54cec0812 100644 --- a/libs/labelbox/tests/unit/test_unit_ontology_kind.py +++ b/libs/labelbox/tests/unit/test_unit_ontology_kind.py @@ -1,4 +1,8 @@ -from labelbox.schema.ontology_kind import OntologyKind, EditorTaskType, EditorTaskTypeMapper +from labelbox.schema.ontology_kind import ( + OntologyKind, + EditorTaskType, + EditorTaskTypeMapper, +) from labelbox.schema.media_type import MediaType @@ -6,17 +10,20 @@ def test_ontology_kind_conversions_from_editor_task_type(): ontology_kind = OntologyKind.ModelEvaluation media_type = MediaType.Conversational editor_task_type = EditorTaskTypeMapper.to_editor_task_type( - ontology_kind, media_type) + ontology_kind, media_type + ) assert editor_task_type == EditorTaskType.ModelChatEvaluation ontology_kind = OntologyKind.Missing media_type = MediaType.Image editor_task_type = EditorTaskTypeMapper.to_editor_task_type( - ontology_kind, media_type) + ontology_kind, media_type + ) assert editor_task_type == EditorTaskType.Missing ontology_kind = OntologyKind.ModelEvaluation media_type = MediaType.Video editor_task_type = EditorTaskTypeMapper.to_editor_task_type( - ontology_kind, media_type) + ontology_kind, media_type + ) assert editor_task_type == EditorTaskType.Missing diff --git 
a/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py b/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py index f9f9a0959..7f6d29d5a 100644 --- a/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py +++ b/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py @@ -22,8 +22,11 @@ def test_validate_labeling_parameter_overrides_invalid_data(): def test_validate_labeling_parameter_overrides_invalid_priority(): mock_data_row = MagicMock(spec=DataRow) mock_data_row.uid = "abc" - data = [(mock_data_row, "invalid"), (UniqueId("efg"), 2), - (GlobalKey("hij"), 3)] + data = [ + (mock_data_row, "invalid"), + (UniqueId("efg"), 2), + (GlobalKey("hij"), 3), + ] with pytest.raises(TypeError): validate_labeling_parameter_overrides(data) @@ -31,7 +34,10 @@ def test_validate_labeling_parameter_overrides_invalid_priority(): def test_validate_labeling_parameter_overrides_invalid_tuple_length(): mock_data_row = MagicMock(spec=DataRow) mock_data_row.uid = "abc" - data = [(mock_data_row, "invalid"), (UniqueId("efg"), 2), - (GlobalKey("hij"))] + data = [ + (mock_data_row, "invalid"), + (UniqueId("efg"), 2), + (GlobalKey("hij")), + ] with pytest.raises(TypeError): validate_labeling_parameter_overrides(data) diff --git a/libs/labelbox/tests/unit/test_unit_query.py b/libs/labelbox/tests/unit/test_unit_query.py index 12db00d2b..83bfeff8a 100644 --- a/libs/labelbox/tests/unit/test_unit_query.py +++ b/libs/labelbox/tests/unit/test_unit_query.py @@ -24,13 +24,15 @@ def test_query_where(): assert q.startswith("x(where: {name_gt: $param_0}){") assert p == {"param_0": ("name", Project.name)} - q, p = query.Query("x", Project, - (Project.name != "name") & (Project.uid <= 42)).format() + q, p = query.Query( + "x", Project, (Project.name != "name") & (Project.uid <= 42) + ).format() assert q.startswith( - "x(where: {AND: [{name_not: $param_0}, {id_lte: $param_1}]}") + "x(where: {AND: [{name_not: $param_0}, {id_lte: $param_1}]}" + ) assert p == { "param_0": ("name", Project.name), - "param_1": (42, Project.uid) + "param_1": (42, Project.uid), } @@ -38,8 +40,9 @@ def test_query_param_declaration(): q, _ = query.Query("x", Project, Project.name > "name").format_top("y") assert q.startswith("query yPyApi($param_0: String!){x") - q, _ = query.Query("x", Project, (Project.name > "name") & - (Project.uid == 42)).format_top("y") + q, _ = query.Query( + "x", Project, (Project.name > "name") & (Project.uid == 42) + ).format_top("y") assert q.startswith("query yPyApi($param_0: String!, $param_1: ID!){x") diff --git a/libs/labelbox/tests/unit/test_unit_search_filters.py b/libs/labelbox/tests/unit/test_unit_search_filters.py index eba8d4db8..b2230bb7f 100644 --- a/libs/labelbox/tests/unit/test_unit_search_filters.py +++ b/libs/labelbox/tests/unit/test_unit_search_filters.py @@ -1,37 +1,68 @@ from datetime import datetime from labelbox.schema.labeling_service import LabelingServiceStatus -from labelbox.schema.search_filters import IntegerValue, RangeDateTimeOperatorWithSingleValue, RangeOperatorWithSingleValue, DateRange, RangeOperatorWithValue, DateRangeValue, DateValue, IdOperator, OperationType, OrganizationFilter, ProjectStageFilter, SharedWithOrganizationFilter, TagFilter, TaskCompletedCountFilter, TaskRemainingCountFilter, WorkforceRequestedDateFilter, WorkforceRequestedDateRangeFilter, WorkforceStageUpdatedFilter, WorkforceStageUpdatedRangeFilter, WorkspaceFilter, build_search_filter +from 
labelbox.schema.search_filters import ( + IntegerValue, + RangeDateTimeOperatorWithSingleValue, + RangeOperatorWithSingleValue, + DateRange, + RangeOperatorWithValue, + DateRangeValue, + DateValue, + IdOperator, + OperationType, + OrganizationFilter, + ProjectStageFilter, + SharedWithOrganizationFilter, + TagFilter, + TaskCompletedCountFilter, + TaskRemainingCountFilter, + WorkforceRequestedDateFilter, + WorkforceRequestedDateRangeFilter, + WorkforceStageUpdatedFilter, + WorkforceStageUpdatedRangeFilter, + WorkspaceFilter, + build_search_filter, +) from labelbox.utils import format_iso_datetime import pytest def test_id_filters(): filters = [ - OrganizationFilter(operator=IdOperator.Is, - values=["clphb4vd7000cd2wv1ktu5cwa"]), - SharedWithOrganizationFilter(operator=IdOperator.Is, - values=["clphb4vd7000cd2wv1ktu5cwa"]), - WorkspaceFilter(operator=IdOperator.Is, - values=["clphb4vd7000cd2wv1ktu5cwa"]), + OrganizationFilter( + operator=IdOperator.Is, values=["clphb4vd7000cd2wv1ktu5cwa"] + ), + SharedWithOrganizationFilter( + operator=IdOperator.Is, values=["clphb4vd7000cd2wv1ktu5cwa"] + ), + WorkspaceFilter( + operator=IdOperator.Is, values=["clphb4vd7000cd2wv1ktu5cwa"] + ), TagFilter(operator=IdOperator.Is, values=["cls1vkrw401ab072vg2pq3t5d"]), - ProjectStageFilter(operator=IdOperator.Is, - values=[LabelingServiceStatus.Requested]), + ProjectStageFilter( + operator=IdOperator.Is, values=[LabelingServiceStatus.Requested] + ), ] - assert build_search_filter( - filters - ) == '[{type: "organization_id", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "shared_with_organizations", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "workspace", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "tag", operator: "is", values: ["cls1vkrw401ab072vg2pq3t5d"]}, {type: "stage", operator: "is", values: ["REQUESTED"]}]' + assert ( + build_search_filter(filters) + == '[{type: "organization_id", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "shared_with_organizations", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "workspace", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "tag", operator: "is", values: ["cls1vkrw401ab072vg2pq3t5d"]}, {type: "stage", operator: "is", values: ["REQUESTED"]}]' + ) def test_stage_filter_with_invalid_values(): with pytest.raises( - ValueError, - match="is not a valid value for ProjectStageFilter") as e: - _ = ProjectStageFilter(operator=IdOperator.Is, - values=[ - LabelingServiceStatus.Requested, - LabelingServiceStatus.Missing - ]), + ValueError, match="is not a valid value for ProjectStageFilter" + ) as e: + _ = ( + ProjectStageFilter( + operator=IdOperator.Is, + values=[ + LabelingServiceStatus.Requested, + LabelingServiceStatus.Missing, + ], + ), + ) def test_date_filters(): @@ -39,46 +70,80 @@ def test_date_filters(): local_time_end = datetime.strptime("2025-01-01", "%Y-%m-%d") filters = [ - WorkforceRequestedDateFilter(value=DateValue( - operator=RangeDateTimeOperatorWithSingleValue.GreaterThanOrEqual, - value=local_time_start)), - WorkforceStageUpdatedFilter(value=DateValue( - operator=RangeDateTimeOperatorWithSingleValue.LessThanOrEqual, - value=local_time_end)), + WorkforceRequestedDateFilter( + value=DateValue( + operator=RangeDateTimeOperatorWithSingleValue.GreaterThanOrEqual, + value=local_time_start, + ) + ), + WorkforceStageUpdatedFilter( + value=DateValue( + operator=RangeDateTimeOperatorWithSingleValue.LessThanOrEqual, + value=local_time_end, + ) + ), ] expected_start = 
format_iso_datetime(local_time_start) expected_end = format_iso_datetime(local_time_end) - expected = '[{type: "workforce_requested_at", value: {operator: "GREATER_THAN_OR_EQUAL", value: "' + expected_start + '"}}, {type: "workforce_stage_updated_at", value: {operator: "LESS_THAN_OR_EQUAL", value: "' + expected_end + '"}}]' + expected = ( + '[{type: "workforce_requested_at", value: {operator: "GREATER_THAN_OR_EQUAL", value: "' + + expected_start + + '"}}, {type: "workforce_stage_updated_at", value: {operator: "LESS_THAN_OR_EQUAL", value: "' + + expected_end + + '"}}]' + ) assert build_search_filter(filters) == expected def test_date_range_filters(): filters = [ - WorkforceRequestedDateRangeFilter(value=DateRangeValue( - operator=RangeOperatorWithValue.Between, - value=DateRange(min=datetime.strptime("2024-01-01T00:00:00-0800", - "%Y-%m-%dT%H:%M:%S%z"), - max=datetime.strptime("2025-01-01T00:00:00-0800", - "%Y-%m-%dT%H:%M:%S%z")))), - WorkforceStageUpdatedRangeFilter(value=DateRangeValue( - operator=RangeOperatorWithValue.Between, - value=DateRange(min=datetime.strptime("2024-01-01T00:00:00-0800", - "%Y-%m-%dT%H:%M:%S%z"), - max=datetime.strptime("2025-01-01T00:00:00-0800", - "%Y-%m-%dT%H:%M:%S%z")))), + WorkforceRequestedDateRangeFilter( + value=DateRangeValue( + operator=RangeOperatorWithValue.Between, + value=DateRange( + min=datetime.strptime( + "2024-01-01T00:00:00-0800", "%Y-%m-%dT%H:%M:%S%z" + ), + max=datetime.strptime( + "2025-01-01T00:00:00-0800", "%Y-%m-%dT%H:%M:%S%z" + ), + ), + ) + ), + WorkforceStageUpdatedRangeFilter( + value=DateRangeValue( + operator=RangeOperatorWithValue.Between, + value=DateRange( + min=datetime.strptime( + "2024-01-01T00:00:00-0800", "%Y-%m-%dT%H:%M:%S%z" + ), + max=datetime.strptime( + "2025-01-01T00:00:00-0800", "%Y-%m-%dT%H:%M:%S%z" + ), + ), + ) + ), ] - assert build_search_filter( - filters - ) == '[{type: "workforce_requested_at", value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}}, {type: "workforce_stage_updated_at", value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}}]' + assert ( + build_search_filter(filters) + == '[{type: "workforce_requested_at", value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}}, {type: "workforce_stage_updated_at", value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}}]' + ) def test_task_count_filters(): filters = [ - TaskCompletedCountFilter(value=IntegerValue( - operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, value=1)), - TaskRemainingCountFilter(value=IntegerValue( - operator=RangeOperatorWithSingleValue.LessThanOrEqual, value=10)), + TaskCompletedCountFilter( + value=IntegerValue( + operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, + value=1, + ) + ), + TaskRemainingCountFilter( + value=IntegerValue( + operator=RangeOperatorWithSingleValue.LessThanOrEqual, value=10 + ) + ), ] expected = '[{type: "task_completed_count", value: {operator: "GREATER_THAN_OR_EQUAL", value: 1}}, {type: "task_remaining_count", value: {operator: "LESS_THAN_OR_EQUAL", value: 10}}]' diff --git a/libs/labelbox/tests/unit/test_unit_webhook.py b/libs/labelbox/tests/unit/test_unit_webhook.py index 405955ce6..ae1b6884d 100644 --- a/libs/labelbox/tests/unit/test_unit_webhook.py +++ b/libs/labelbox/tests/unit/test_unit_webhook.py @@ -13,8 +13,7 @@ def test_webhook_create_with_no_secret(rand_gen): with pytest.raises(ValueError) as exc_info: 
Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "Secret must be a non-empty string." + assert str(exc_info.value) == "Secret must be a non-empty string." def test_webhook_create_with_no_topics(rand_gen): @@ -26,8 +25,7 @@ def test_webhook_create_with_no_topics(rand_gen): with pytest.raises(ValueError) as exc_info: Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "Topics must be a non-empty list." + assert str(exc_info.value) == "Topics must be a non-empty list." def test_webhook_create_with_no_url(rand_gen): @@ -39,5 +37,4 @@ def test_webhook_create_with_no_url(rand_gen): with pytest.raises(ValueError) as exc_info: Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "URL must be a non-empty string." + assert str(exc_info.value) == "URL must be a non-empty string." diff --git a/libs/labelbox/tests/unit/test_utils.py b/libs/labelbox/tests/unit/test_utils.py index dfd72c335..969f3a46b 100644 --- a/libs/labelbox/tests/unit/test_utils.py +++ b/libs/labelbox/tests/unit/test_utils.py @@ -1,21 +1,34 @@ import pytest -from labelbox.utils import format_iso_datetime, format_iso_from_string, sentence_case +from labelbox.utils import ( + format_iso_datetime, + format_iso_from_string, + sentence_case, +) -@pytest.mark.parametrize('datetime_str, expected_datetime_str', - [('2011-11-04T00:05:23Z', '2011-11-04T00:05:23Z'), - ('2011-11-04T00:05:23+00:00', '2011-11-04T00:05:23Z'), - ('2011-11-04T00:05:23+05:00', '2011-11-03T19:05:23Z'), - ('2011-11-04T00:05:23', '2011-11-04T00:05:23Z')]) +@pytest.mark.parametrize( + "datetime_str, expected_datetime_str", + [ + ("2011-11-04T00:05:23Z", "2011-11-04T00:05:23Z"), + ("2011-11-04T00:05:23+00:00", "2011-11-04T00:05:23Z"), + ("2011-11-04T00:05:23+05:00", "2011-11-03T19:05:23Z"), + ("2011-11-04T00:05:23", "2011-11-04T00:05:23Z"), + ], +) def test_datetime_parsing(datetime_str, expected_datetime_str): # NOTE I would normally not take 'expected' using another function from sdk code, but in this case this is exactly the usage in _validate_parse_datetime - assert format_iso_datetime( - format_iso_from_string(datetime_str)) == expected_datetime_str + assert ( + format_iso_datetime(format_iso_from_string(datetime_str)) + == expected_datetime_str + ) @pytest.mark.parametrize( - 'str, expected_str', - [('AUDIO', 'Audio'), - ('LLM_PROMPT_RESPONSE_CREATION', 'Llm prompt response creation')]) + "str, expected_str", + [ + ("AUDIO", "Audio"), + ("LLM_PROMPT_RESPONSE_CREATION", "Llm prompt response creation"), + ], +) def test_sentence_case(str, expected_str): assert sentence_case(str) == expected_str diff --git a/libs/labelbox/tests/utils.py b/libs/labelbox/tests/utils.py index 6fa2a8d8d..595fa0c76 100644 --- a/libs/labelbox/tests/utils.py +++ b/libs/labelbox/tests/utils.py @@ -14,9 +14,9 @@ def remove_keys_recursive(d, keys): # NOTE this uses quite a primitive check for cuids but I do not think it is worth coming up with a better one # Also this function is NOT written with performance in mind, good for small to mid size dicts like we have in our test def rename_cuid_key_recursive(d): - new_key = '' + new_key = "" for k in list(d.keys()): - if len(k) == 25 and not k.isalpha(): #primitive check for cuid + if len(k) == 25 and not k.isalpha(): # primitive check for cuid d[new_key] = d.pop(k) for k, v in d.items(): if isinstance(v, dict): @@ -27,4 +27,4 @@ def rename_cuid_key_recursive(d): rename_cuid_key_recursive(i) -INTEGRATION_SNAPSHOT_DIRECTORY = 
'tests/integration/snapshots' +INTEGRATION_SNAPSHOT_DIRECTORY = "tests/integration/snapshots" From 1466516508e17c4a2a2f2f3ba0bb720677df91d5 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 11 Sep 2024 15:07:57 -0500 Subject: [PATCH 3/8] fixed error --- .../src/labelbox/schema/bulk_import_request.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/bulk_import_request.py b/libs/labelbox/src/labelbox/schema/bulk_import_request.py index 7caa2c6eb..44ac7cd6a 100644 --- a/libs/labelbox/src/labelbox/schema/bulk_import_request.py +++ b/libs/labelbox/src/labelbox/schema/bulk_import_request.py @@ -787,10 +787,8 @@ def validate_feature_schemas( # A union with custom construction logic to improve error messages class NDClassification( SpecialUnion, - Type[ - Union[ # type: ignore - NDText, NDRadio, NDChecklist - ] + Type[ # type: ignore + Union[NDText, NDRadio, NDChecklist] ], ): ... @@ -966,8 +964,8 @@ class NDMask(NDBaseTool): # A union with custom construction logic to improve error messages class NDTool( SpecialUnion, - Type[ - Union[ # type: ignore + Type[ # type: ignore + Union[ NDMask, NDTextEntity, NDPoint, @@ -981,10 +979,8 @@ class NDTool( class NDAnnotation( SpecialUnion, - Type[ - Union[ # type: ignore - NDTool, NDClassification - ] + Type[ # type: ignore + Union[NDTool, NDClassification] ], ): @classmethod From 561f46369b89fd65bd21048d210e0c69ddcebc4f Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 11 Sep 2024 15:00:54 -0500 Subject: [PATCH 4/8] testing workflow --- .github/workflows/python-package-shared.yml | 7 +++++-- libs/labelbox/pyproject.toml | 7 ++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/.github/workflows/python-package-shared.yml b/.github/workflows/python-package-shared.yml index 4311020d8..acd30b299 100644 --- a/.github/workflows/python-package-shared.yml +++ b/.github/workflows/python-package-shared.yml @@ -18,7 +18,7 @@ on: test-env: required: true type: string - fixture-profile: + fixture-profile: required: true type: boolean @@ -36,6 +36,9 @@ jobs: - name: Linting working-directory: libs/labelbox run: rye run lint + - name: Format + working-directory: libs/labelbox + run: rye fmt --check integration: runs-on: ubuntu-latest concurrency: @@ -78,4 +81,4 @@ jobs: run: | rye sync -f --features labelbox/data rye run unit -n 32 - rye run data -n 32 \ No newline at end of file + rye run data -n 32 diff --git a/libs/labelbox/pyproject.toml b/libs/labelbox/pyproject.toml index ac167cdcb..771117a01 100644 --- a/libs/labelbox/pyproject.toml +++ b/libs/labelbox/pyproject.toml @@ -64,7 +64,6 @@ build-backend = "hatchling.build" [tool.rye] managed = true dev-dependencies = [ - "yapf>=0.40.2", "mypy>=1.9.0", "types-pillow>=10.2.0.20240311", "types-python-dateutil>=2.9.0.20240316", @@ -72,6 +71,9 @@ dev-dependencies = [ "types-tqdm>=4.66.0.20240106", ] +[tool.ruff] +line-length = 80 + [tool.rye.scripts] unit = "pytest tests/unit" # https://github.com/Labelbox/labelbox-python/blob/7c84fdffbc14fd1f69d2a6abdcc0087dc557fa4e/Makefile @@ -87,9 +89,8 @@ unit = "pytest tests/unit" # LABELBOX_TEST_BASE_URL="http://host.docker.internal:8080" \ integration = { cmd = "pytest tests/integration" } data = { cmd = "pytest tests/data" } -yapf-lint = "yapf tests src -i --verbose --recursive --parallel --style \"google\"" mypy-lint = "mypy src --pretty --show-error-codes --non-interactive --install-types" -lint = { chain 
= ["yapf-lint", "mypy-lint"] } +lint = { chain = ["mypy-lint"] } test = { chain = ["lint", "unit", "integration"] } [tool.hatch.metadata] From a12c420a69bee18c9c9ec5ba707f34d4b4e9f4cb Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 11 Sep 2024 15:05:11 -0500 Subject: [PATCH 5/8] reformatted --- libs/labelbox/src/labelbox/__init__.py | 43 +- libs/labelbox/src/labelbox/adv_client.py | 62 +- libs/labelbox/src/labelbox/client.py | 1192 ++++++++----- .../data/annotation_types/__init__.py | 9 +- .../data/annotation_types/annotation.py | 4 +- .../data/annotation_types/base_annotation.py | 6 +- .../classification/__init__.py | 3 +- .../classification/classification.py | 19 +- .../data/annotation_types/collection.py | 43 +- .../data/annotation_types/data/__init__.py | 2 +- .../data/annotation_types/data/audio.py | 2 +- .../data/annotation_types/data/base_data.py | 1 + .../data/annotation_types/data/dicom.py | 2 +- .../data/annotation_types/data/document.py | 2 +- .../data/generic_data_row_data.py | 6 +- .../data/annotation_types/data/html.py | 2 +- .../data/llm_prompt_creation.py | 2 +- .../data/llm_prompt_response_creation.py | 5 +- .../data/llm_response_creation.py | 2 +- .../data/annotation_types/data/raster.py | 55 +- .../data/annotation_types/data/text.py | 25 +- .../data/annotation_types/data/tiled_image.py | 330 ++-- .../data/annotation_types/data/video.py | 41 +- .../labelbox/data/annotation_types/feature.py | 1 + .../annotation_types/geometry/geometry.py | 41 +- .../data/annotation_types/geometry/line.py | 43 +- .../data/annotation_types/geometry/mask.py | 42 +- .../data/annotation_types/geometry/point.py | 31 +- .../data/annotation_types/geometry/polygon.py | 31 +- .../annotation_types/geometry/rectangle.py | 54 +- .../labelbox/data/annotation_types/label.py | 121 +- .../llm_prompt_response/__init__.py | 2 +- .../llm_prompt_response/prompt.py | 12 +- .../data/annotation_types/metrics/__init__.py | 6 +- .../data/annotation_types/metrics/base.py | 20 +- .../metrics/confusion_matrix.py | 17 +- .../data/annotation_types/metrics/scalar.py | 29 +- .../src/labelbox/data/annotation_types/mmc.py | 11 +- .../ner/conversation_entity.py | 5 +- .../annotation_types/ner/document_entity.py | 3 +- .../data/annotation_types/ner/text_entity.py | 11 +- .../data/annotation_types/relationship.py | 6 +- .../labelbox/data/annotation_types/types.py | 19 +- .../labelbox/data/annotation_types/video.py | 105 +- libs/labelbox/src/labelbox/data/generator.py | 7 +- .../src/labelbox/data/metrics/__init__.py | 5 +- .../metrics/confusion_matrix/calculation.py | 171 +- .../confusion_matrix/confusion_matrix.py | 76 +- .../src/labelbox/data/metrics/group.py | 71 +- .../labelbox/data/metrics/iou/calculation.py | 199 ++- .../src/labelbox/data/metrics/iou/iou.py | 53 +- libs/labelbox/src/labelbox/data/mixins.py | 13 +- libs/labelbox/src/labelbox/data/ontology.py | 70 +- .../data/serialization/coco/annotation.py | 19 +- .../data/serialization/coco/categories.py | 3 +- .../data/serialization/coco/converter.py | 78 +- .../labelbox/data/serialization/coco/image.py | 2 +- .../serialization/coco/instance_dataset.py | 188 +- .../serialization/coco/panoptic_dataset.py | 149 +- .../labelbox/data/serialization/coco/path.py | 2 +- .../data/serialization/ndjson/base.py | 12 +- .../serialization/ndjson/classification.py | 436 +++-- .../data/serialization/ndjson/converter.py | 50 +- .../data/serialization/ndjson/label.py | 275 ++- .../data/serialization/ndjson/metric.py | 110 +- 
.../labelbox/data/serialization/ndjson/mmc.py | 36 +- .../data/serialization/ndjson/objects.py | 871 +++++---- .../data/serialization/ndjson/relationship.py | 49 +- libs/labelbox/src/labelbox/exceptions.py | 51 +- libs/labelbox/src/labelbox/orm/comparison.py | 46 +- libs/labelbox/src/labelbox/orm/db_object.py | 122 +- libs/labelbox/src/labelbox/orm/model.py | 95 +- libs/labelbox/src/labelbox/orm/query.py | 214 ++- libs/labelbox/src/labelbox/pagination.py | 80 +- libs/labelbox/src/labelbox/parser.py | 7 +- libs/labelbox/src/labelbox/schema/__init__.py | 2 +- .../src/labelbox/schema/annotation_import.py | 395 +++-- .../src/labelbox/schema/asset_attachment.py | 31 +- libs/labelbox/src/labelbox/schema/batch.py | 117 +- .../labelbox/src/labelbox/schema/benchmark.py | 14 +- .../labelbox/schema/bulk_import_request.py | 503 +++--- libs/labelbox/src/labelbox/schema/catalog.py | 135 +- .../schema/confidence_presence_checker.py | 5 +- .../labelbox/schema/create_batches_task.py | 6 +- libs/labelbox/src/labelbox/schema/data_row.py | 247 +-- .../src/labelbox/schema/data_row_metadata.py | 449 +++-- libs/labelbox/src/labelbox/schema/dataset.py | 403 +++-- .../labelbox/src/labelbox/schema/embedding.py | 11 +- libs/labelbox/src/labelbox/schema/enums.py | 29 +- .../src/labelbox/schema/export_filters.py | 142 +- .../src/labelbox/schema/export_params.py | 11 +- .../src/labelbox/schema/export_task.py | 246 ++- .../src/labelbox/schema/foundry/app.py | 4 +- .../labelbox/schema/foundry/foundry_client.py | 28 +- .../src/labelbox/schema/foundry/model.py | 2 +- .../src/labelbox/schema/iam_integration.py | 6 +- libs/labelbox/src/labelbox/schema/id_type.py | 3 +- .../src/labelbox/schema/identifiables.py | 6 +- .../schema/internal/data_row_uploader.py | 31 +- .../schema/internal/data_row_upsert_item.py | 49 +- .../internal/descriptor_file_creator.py | 108 +- libs/labelbox/src/labelbox/schema/invite.py | 9 +- libs/labelbox/src/labelbox/schema/label.py | 14 +- .../src/labelbox/schema/labeling_frontend.py | 6 +- .../src/labelbox/schema/labeling_service.py | 26 +- .../schema/labeling_service_dashboard.py | 62 +- .../schema/labeling_service_status.py | 18 +- .../src/labelbox/schema/media_type.py | 31 +- libs/labelbox/src/labelbox/schema/model.py | 33 +- .../src/labelbox/schema/model_config.py | 4 +- .../labelbox/src/labelbox/schema/model_run.py | 467 ++--- libs/labelbox/src/labelbox/schema/ontology.py | 292 +-- .../src/labelbox/schema/ontology_kind.py | 76 +- .../src/labelbox/schema/organization.py | 120 +- libs/labelbox/src/labelbox/schema/project.py | 889 ++++++---- .../labelbox/schema/project_model_config.py | 16 +- .../src/labelbox/schema/project_overview.py | 17 +- .../labelbox/schema/project_resource_tag.py | 2 +- .../src/labelbox/schema/resource_tag.py | 2 +- libs/labelbox/src/labelbox/schema/review.py | 6 +- libs/labelbox/src/labelbox/schema/role.py | 14 +- .../src/labelbox/schema/search_filters.py | 147 +- .../schema/send_to_annotate_params.py | 61 +- .../src/labelbox/schema/serialization.py | 5 +- libs/labelbox/src/labelbox/schema/slice.py | 182 +- libs/labelbox/src/labelbox/schema/task.py | 167 +- libs/labelbox/src/labelbox/schema/user.py | 40 +- .../src/labelbox/schema/user_group.py | 96 +- libs/labelbox/src/labelbox/schema/webhook.py | 37 +- libs/labelbox/src/labelbox/types.py | 2 +- libs/labelbox/src/labelbox/typing_imports.py | 5 +- libs/labelbox/src/labelbox/utils.py | 31 +- libs/labelbox/tests/conftest.py | 604 ++++--- .../tests/data/annotation_import/conftest.py | 1580 ++++++++--------- 
.../test_annotation_import_limit.py | 57 +- .../test_bulk_import_request.py | 143 +- .../data/annotation_import/test_data_types.py | 24 +- .../test_generic_data_types.py | 233 ++- .../annotation_import/test_label_import.py | 108 +- .../test_mal_prediction_import.py | 49 +- .../test_mea_prediction_import.py | 227 ++- .../data/annotation_import/test_model_run.py | 87 +- .../test_ndjson_validation.py | 157 +- .../test_send_to_annotate_mea.py | 44 +- .../test_upsert_prediction_import.py | 101 +- .../classification/test_classification.py | 190 +- .../data/annotation_types/data/test_raster.py | 12 +- .../data/annotation_types/data/test_text.py | 20 +- .../data/annotation_types/data/test_video.py | 14 +- .../annotation_types/geometry/test_line.py | 2 +- .../annotation_types/geometry/test_mask.py | 143 +- .../annotation_types/geometry/test_point.py | 2 +- .../annotation_types/geometry/test_polygon.py | 8 +- .../geometry/test_rectangle.py | 6 +- .../data/annotation_types/test_annotation.py | 57 +- .../data/annotation_types/test_collection.py | 69 +- .../tests/data/annotation_types/test_label.py | 274 +-- .../data/annotation_types/test_metrics.py | 242 +-- .../tests/data/annotation_types/test_ner.py | 18 +- .../data/annotation_types/test_tiled_image.py | 68 +- .../tests/data/annotation_types/test_video.py | 15 +- libs/labelbox/tests/data/conftest.py | 47 +- libs/labelbox/tests/data/export/conftest.py | 568 +++--- .../data/export/legacy/test_export_catalog.py | 10 +- .../data/export/legacy/test_export_dataset.py | 26 +- .../export/legacy/test_export_model_run.py | 23 +- .../data/export/legacy/test_export_project.py | 181 +- .../data/export/legacy/test_export_slice.py | 10 +- .../data/export/legacy/test_export_video.py | 275 +-- .../data/export/legacy/test_legacy_export.py | 179 +- .../test_export_data_rows_streamable.py | 86 +- .../test_export_dataset_streamable.py | 68 +- .../test_export_embeddings_streamable.py | 74 +- .../test_export_model_run_streamable.py | 28 +- .../test_export_project_streamable.py | 208 ++- .../test_export_video_streamable.py | 108 +- .../data/metrics/confusion_matrix/conftest.py | 598 ++++--- .../test_confusion_matrix_data_row.py | 63 +- .../test_confusion_matrix_feature.py | 54 +- .../data/metrics/iou/data_row/conftest.py | 1262 ++++++------- .../data/metrics/iou/feature/conftest.py | 301 ++-- .../metrics/iou/feature/test_feature_iou.py | 3 +- .../data/serialization/coco/test_coco.py | 26 +- .../serialization/ndjson/test_checklist.py | 408 +++-- .../ndjson/test_classification.py | 10 +- .../serialization/ndjson/test_conversation.py | 194 +- .../serialization/ndjson/test_data_gen.py | 54 +- .../data/serialization/ndjson/test_dicom.py | 197 +- .../serialization/ndjson/test_document.py | 56 +- .../ndjson/test_export_video_objects.py | 1140 ++++++------ .../serialization/ndjson/test_free_text.py | 113 +- .../serialization/ndjson/test_global_key.py | 33 +- .../data/serialization/ndjson/test_image.py | 91 +- .../data/serialization/ndjson/test_metric.py | 22 +- .../data/serialization/ndjson/test_mmc.py | 13 +- .../ndjson/test_ndlabel_subclass_matching.py | 12 +- .../data/serialization/ndjson/test_nested.py | 7 +- .../serialization/ndjson/test_polyline.py | 13 +- .../data/serialization/ndjson/test_radio.py | 104 +- .../serialization/ndjson/test_rectangle.py | 48 +- .../serialization/ndjson/test_relationship.py | 12 +- .../data/serialization/ndjson/test_text.py | 41 +- .../serialization/ndjson/test_text_entity.py | 13 +- .../data/serialization/ndjson/test_video.py | 760 ++++---- 
.../tests/data/test_data_row_metadata.py | 287 +-- .../tests/data/test_prefetch_generator.py | 3 +- libs/labelbox/tests/integration/conftest.py | 576 +++--- .../integration/schema/test_user_group.py | 16 +- libs/labelbox/tests/integration/test_batch.py | 189 +- .../tests/integration/test_batches.py | 12 +- .../test_chat_evaluation_ontology_project.py | 110 +- .../tests/integration/test_client_errors.py | 13 +- .../test_data_row_delete_metadata.py | 218 ++- .../tests/integration/test_data_rows.py | 855 +++++---- .../integration/test_data_rows_upsert.py | 309 ++-- .../tests/integration/test_dataset.py | 78 +- .../integration/test_delegated_access.py | 108 +- .../tests/integration/test_embedding.py | 21 +- .../tests/integration/test_ephemeral.py | 12 +- .../tests/integration/test_feature_schema.py | 56 +- .../tests/integration/test_filtering.py | 11 +- .../tests/integration/test_foundry.py | 103 +- .../tests/integration/test_global_keys.py | 236 ++- libs/labelbox/tests/integration/test_label.py | 16 +- .../integration/test_labeling_dashboard.py | 82 +- .../integration/test_labeling_frontend.py | 6 +- .../test_labeling_parameter_overrides.py | 47 +- .../integration/test_labeling_service.py | 43 +- .../tests/integration/test_legacy_project.py | 15 +- .../tests/integration/test_model_config.py | 12 +- .../test_offline_chat_evaluation_project.py | 13 +- .../tests/integration/test_ontology.py | 251 +-- .../tests/integration/test_project.py | 119 +- .../integration/test_project_model_config.py | 63 +- .../test_project_set_model_setup_complete.py | 34 +- .../tests/integration/test_project_setup.py | 32 +- ...test_prompt_response_generation_project.py | 162 +- .../test_response_creation_project.py | 18 +- .../integration/test_send_to_annotate.py | 39 +- libs/labelbox/tests/integration/test_task.py | 68 +- .../tests/integration/test_task_queue.py | 30 +- .../tests/integration/test_user_and_org.py | 2 +- .../tests/integration/test_user_management.py | 107 +- .../tests/integration/test_webhook.py | 27 +- libs/labelbox/tests/unit/conftest.py | 145 +- .../unit/export_task/test_export_task.py | 105 +- .../export_task/test_unit_file_converter.py | 11 +- .../test_unit_file_retriever_by_line.py | 55 +- .../test_unit_file_retriever_by_offset.py | 37 +- .../export_task/test_unit_json_converter.py | 18 +- .../tests/unit/schema/test_user_group.py | 197 +- .../tests/unit/test_annotation_import.py | 48 +- .../tests/unit/test_data_row_upsert_data.py | 119 +- libs/labelbox/tests/unit/test_exceptions.py | 21 +- .../tests/unit/test_label_data_type.py | 37 +- libs/labelbox/tests/unit/test_mal_import.py | 85 +- .../tests/unit/test_ndjson_parsing.py | 2 +- libs/labelbox/tests/unit/test_project.py | 24 +- ...est_unit_delete_batch_data_row_metadata.py | 46 +- .../unit/test_unit_descriptor_file_creator.py | 35 +- .../tests/unit/test_unit_entity_meta.py | 14 +- .../tests/unit/test_unit_export_filters.py | 43 +- .../tests/unit/test_unit_label_import.py | 45 +- .../labelbox/tests/unit/test_unit_ontology.py | 306 ++-- .../tests/unit/test_unit_ontology_kind.py | 15 +- ...t_validate_labeling_parameter_overrides.py | 14 +- libs/labelbox/tests/unit/test_unit_query.py | 15 +- .../tests/unit/test_unit_search_filters.py | 155 +- libs/labelbox/tests/unit/test_unit_webhook.py | 9 +- libs/labelbox/tests/unit/test_utils.py | 35 +- libs/labelbox/tests/utils.py | 6 +- 271 files changed, 16955 insertions(+), 13067 deletions(-) diff --git a/libs/labelbox/src/labelbox/__init__.py b/libs/labelbox/src/labelbox/__init__.py index 
ac7efdc96..633e8f4c2 100644 --- a/libs/labelbox/src/labelbox/__init__.py +++ b/libs/labelbox/src/labelbox/__init__.py @@ -7,7 +7,12 @@ from labelbox.schema.model import Model from labelbox.schema.model_config import ModelConfig from labelbox.schema.bulk_import_request import BulkImportRequest -from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport, LabelImport, MEAToMALPredictionImport +from labelbox.schema.annotation_import import ( + MALPredictionImport, + MEAPredictionImport, + LabelImport, + MEAToMALPredictionImport, +) from labelbox.schema.dataset import Dataset from labelbox.schema.data_row import DataRow from labelbox.schema.catalog import Catalog @@ -18,16 +23,39 @@ from labelbox.schema.user import User from labelbox.schema.organization import Organization from labelbox.schema.task import Task -from labelbox.schema.export_task import StreamType, ExportTask, JsonConverter, JsonConverterOutput, FileConverter, FileConverterOutput, BufferedJsonConverterOutput -from labelbox.schema.labeling_frontend import LabelingFrontend, LabelingFrontendOptions +from labelbox.schema.export_task import ( + StreamType, + ExportTask, + JsonConverter, + JsonConverterOutput, + FileConverter, + FileConverterOutput, + BufferedJsonConverterOutput, +) +from labelbox.schema.labeling_frontend import ( + LabelingFrontend, + LabelingFrontendOptions, +) from labelbox.schema.asset_attachment import AssetAttachment from labelbox.schema.webhook import Webhook -from labelbox.schema.ontology import Ontology, OntologyBuilder, Classification, Option, Tool, FeatureSchema +from labelbox.schema.ontology import ( + Ontology, + OntologyBuilder, + Classification, + Option, + Tool, + FeatureSchema, +) from labelbox.schema.ontology import PromptResponseClassification from labelbox.schema.ontology import ResponseOption from labelbox.schema.role import Role, ProjectRole from labelbox.schema.invite import Invite, InviteLimit -from labelbox.schema.data_row_metadata import DataRowMetadataOntology, DataRowMetadataField, DataRowMetadata, DeleteDataRowMetadata +from labelbox.schema.data_row_metadata import ( + DataRowMetadataOntology, + DataRowMetadataField, + DataRowMetadata, + DeleteDataRowMetadata, +) from labelbox.schema.model_run import ModelRun, DataSplit from labelbox.schema.benchmark import Benchmark from labelbox.schema.iam_integration import IAMIntegration @@ -42,7 +70,10 @@ from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.schema.ontology_kind import OntologyKind -from labelbox.schema.project_overview import ProjectOverview, ProjectOverviewDetailed +from labelbox.schema.project_overview import ( + ProjectOverview, + ProjectOverviewDetailed, +) from labelbox.schema.labeling_service import LabelingService from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard from labelbox.schema.labeling_service_status import LabelingServiceStatus diff --git a/libs/labelbox/src/labelbox/adv_client.py b/libs/labelbox/src/labelbox/adv_client.py index 6eab78d68..626ac0279 100644 --- a/libs/labelbox/src/labelbox/adv_client.py +++ b/libs/labelbox/src/labelbox/adv_client.py @@ -12,7 +12,6 @@ class AdvClient: - def __init__(self, endpoint: str, api_key: str): self.endpoint = endpoint self.api_key = api_key @@ -32,8 +31,9 @@ def get_embeddings(self) -> List[Dict[str, Any]]: return self._request("GET", "/adv/v1/embeddings").get("results", []) def import_vectors_from_file(self, id: str, file_path: 
str, callback=None): - self._send_ndjson(f"/adv/v1/embeddings/{id}/_import_ndjson", file_path, - callback) + self._send_ndjson( + f"/adv/v1/embeddings/{id}/_import_ndjson", file_path, callback + ) def get_imported_vector_count(self, id: str) -> int: data = self._request("GET", f"/adv/v1/embeddings/{id}/vectors/_count") @@ -41,38 +41,42 @@ def get_imported_vector_count(self, id: str) -> int: def _create_session(self) -> Session: session = requests.session() - session.headers.update({ - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - }) + session.headers.update( + { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + ) return session - def _request(self, - method: str, - path: str, - data: Optional[Dict[str, Any]] = None, - headers: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + def _request( + self, + method: str, + path: str, + data: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: url = f"{self.endpoint}{path}" requests_data = None if data: requests_data = json.dumps(data) - response = self.session.request(method, - url, - data=requests_data, - headers=headers) + response = self.session.request( + method, url, data=requests_data, headers=headers + ) if response.status_code != requests.codes.ok: - message = response.json().get('message') + message = response.json().get("message") if message: raise LabelboxError(message) else: response.raise_for_status() return response.json() - def _send_ndjson(self, - path: str, - file_path: str, - callback: Optional[Callable[[Dict[str, Any]], - None]] = None): + def _send_ndjson( + self, + path: str, + file_path: str, + callback: Optional[Callable[[Dict[str, Any]], None]] = None, + ): """ Sends an NDJson file in chunks. 
@@ -87,7 +91,7 @@ def upload_chunk(_buffer, _count): _headers = { "Content-Type": "application/x-ndjson", "X-Content-Lines": str(_count), - "Content-Length": str(buffer.tell()) + "Content-Length": str(buffer.tell()), } rsp = self._send_bytes(f"{self.endpoint}{path}", _buffer, _headers) rsp.raise_for_status() @@ -96,7 +100,7 @@ def upload_chunk(_buffer, _count): buffer = io.BytesIO() count = 0 - with open(file_path, 'rb') as fp: + with open(file_path, "rb") as fp: for line in fp: buffer.write(line) count += 1 @@ -107,10 +111,12 @@ def upload_chunk(_buffer, _count): if count: upload_chunk(buffer, count) - def _send_bytes(self, - url: str, - buffer: io.BytesIO, - headers: Optional[Dict[str, Any]] = None) -> Response: + def _send_bytes( + self, + url: str, + buffer: io.BytesIO, + headers: Optional[Dict[str, Any]] = None, + ) -> Response: buffer.seek(0) return self.session.put(url, headers=headers, data=buffer) diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 431ddbdc4..cda55c282 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -26,7 +26,9 @@ from labelbox.orm.model import Entity, Field from labelbox.pagination import PaginatedCollection from labelbox.schema import role -from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy +from labelbox.schema.conflict_resolution_strategy import ( + ConflictResolutionStrategy, +) from labelbox.schema.data_row import DataRow from labelbox.schema.catalog import Catalog from labelbox.schema.data_row_metadata import DataRowMetadataOntology @@ -38,26 +40,46 @@ from labelbox.schema.identifiables import DataRowIds from labelbox.schema.identifiables import GlobalKeys from labelbox.schema.labeling_frontend import LabelingFrontend -from labelbox.schema.media_type import MediaType, get_media_type_validation_error +from labelbox.schema.media_type import ( + MediaType, + get_media_type_validation_error, +) from labelbox.schema.model import Model from labelbox.schema.model_config import ModelConfig from labelbox.schema.model_run import ModelRun from labelbox.schema.ontology import Ontology, DeleteFeatureFromOntologyResult -from labelbox.schema.ontology import Tool, Classification, FeatureSchema, PromptResponseClassification +from labelbox.schema.ontology import ( + Tool, + Classification, + FeatureSchema, + PromptResponseClassification, +) from labelbox.schema.organization import Organization from labelbox.schema.project import Project -from labelbox.schema.quality_mode import QualityMode, BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS, \ - BENCHMARK_AUTO_AUDIT_PERCENTAGE, CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS, CONSENSUS_AUTO_AUDIT_PERCENTAGE +from labelbox.schema.quality_mode import ( + QualityMode, + BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS, + BENCHMARK_AUTO_AUDIT_PERCENTAGE, + CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS, + CONSENSUS_AUTO_AUDIT_PERCENTAGE, +) from labelbox.schema.queue_mode import QueueMode from labelbox.schema.role import Role -from labelbox.schema.send_to_annotate_params import SendToAnnotateFromCatalogParams, build_destination_task_queue_input, \ - build_predictions_input, build_annotations_input +from labelbox.schema.send_to_annotate_params import ( + SendToAnnotateFromCatalogParams, + build_destination_task_queue_input, + build_predictions_input, + build_annotations_input, +) from labelbox.schema.slice import CatalogSlice, ModelSlice from labelbox.schema.task import Task, DataUpsertTask from labelbox.schema.user import User from 
labelbox.schema.label_score import LabelScore -from labelbox.schema.ontology_kind import (OntologyKind, EditorTaskTypeMapper, - EditorTaskType) +from labelbox.schema.ontology_kind import ( + OntologyKind, + EditorTaskTypeMapper, + EditorTaskType, +) from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard logger = logging.getLogger(__name__) @@ -72,20 +94,22 @@ def python_version_info(): class Client: - """ A Labelbox client. + """A Labelbox client. Contains info necessary for connecting to a Labelbox server (URL, authentication key). Provides functions for querying and creating top-level data objects (Projects, Datasets). """ - def __init__(self, - api_key=None, - endpoint='https://api.labelbox.com/graphql', - enable_experimental=False, - app_url="https://app.labelbox.com", - rest_endpoint="https://api.labelbox.com/api/v1"): - """ Creates and initializes a Labelbox Client. + def __init__( + self, + api_key=None, + endpoint="https://api.labelbox.com/graphql", + enable_experimental=False, + app_url="https://app.labelbox.com", + rest_endpoint="https://api.labelbox.com/api/v1", + ): + """Creates and initializes a Labelbox Client. Logging is defaulted to level WARNING. To receive more verbose output to console, update `logging.level` to the appropriate level. @@ -106,7 +130,8 @@ def __init__(self, if api_key is None: if _LABELBOX_API_KEY not in os.environ: raise labelbox.exceptions.AuthenticationError( - "Labelbox API key not provided") + "Labelbox API key not provided" + ) api_key = os.environ[_LABELBOX_API_KEY] self.api_key = api_key @@ -123,7 +148,8 @@ def __init__(self, self._connection: requests.Session = self._init_connection() def _init_connection(self) -> requests.Session: - connection = requests.Session( + connection = ( + requests.Session() ) # using default connection pool size of 10 connection.headers.update(self._default_headers()) @@ -135,26 +161,31 @@ def headers(self) -> MappingProxyType: def _default_headers(self): return { - 'Authorization': 'Bearer %s' % self.api_key, - 'Accept': 'application/json', - 'Content-Type': 'application/json', - 'X-User-Agent': f"python-sdk {SDK_VERSION}", - 'X-Python-Version': f"{python_version_info()}", + "Authorization": "Bearer %s" % self.api_key, + "Accept": "application/json", + "Content-Type": "application/json", + "X-User-Agent": f"python-sdk {SDK_VERSION}", + "X-Python-Version": f"{python_version_info()}", } - @retry.Retry(predicate=retry.if_exception_type( - labelbox.exceptions.InternalServerError, - labelbox.exceptions.TimeoutError)) - def execute(self, - query=None, - params=None, - data=None, - files=None, - timeout=60.0, - experimental=False, - error_log_key="message", - raise_return_resource_not_found=False): - """ Sends a request to the server for the execution of the + @retry.Retry( + predicate=retry.if_exception_type( + labelbox.exceptions.InternalServerError, + labelbox.exceptions.TimeoutError, + ) + ) + def execute( + self, + query=None, + params=None, + data=None, + files=None, + timeout=60.0, + experimental=False, + error_log_key="message", + raise_return_resource_not_found=False, + ): + """Sends a request to the server for the execution of the given query. 
Checks the response for errors and wraps errors @@ -199,26 +230,30 @@ def convert_value(value): params = { key: convert_value(value) for key, value in params.items() } - data = json.dumps({ - 'query': query, - 'variables': params - }).encode('utf-8') + data = json.dumps({"query": query, "variables": params}).encode( + "utf-8" + ) elif data is None: raise ValueError("query and data cannot both be none") - endpoint = self.endpoint if not experimental else self.endpoint.replace( - "/graphql", "/_gql") + endpoint = ( + self.endpoint + if not experimental + else self.endpoint.replace("/graphql", "/_gql") + ) try: headers = self._connection.headers.copy() if files: - del headers['Content-Type'] - del headers['Accept'] - request = requests.Request('POST', - endpoint, - headers=headers, - data=data, - files=files if files else None) + del headers["Content-Type"] + del headers["Accept"] + request = requests.Request( + "POST", + endpoint, + headers=headers, + data=data, + files=files if files else None, + ) prepped: requests.PreparedRequest = request.prepare() @@ -231,20 +266,30 @@ def convert_value(value): raise labelbox.exceptions.NetworkError(e) except Exception as e: raise labelbox.exceptions.LabelboxError( - "Unknown error during Client.query(): " + str(e), e) + "Unknown error during Client.query(): " + str(e), e + ) - if 200 <= response.status_code < 300 or response.status_code < 500 or response.status_code >= 600: + if ( + 200 <= response.status_code < 300 + or response.status_code < 500 + or response.status_code >= 600 + ): try: r_json = response.json() except Exception: raise labelbox.exceptions.LabelboxError( - "Failed to parse response as JSON: %s" % response.text) + "Failed to parse response as JSON: %s" % response.text + ) else: - if "upstream connect error or disconnect/reset before headers" in response.text: + if ( + "upstream connect error or disconnect/reset before headers" + in response.text + ): raise labelbox.exceptions.InternalServerError( - "Connection reset") + "Connection reset" + ) elif response.status_code == 502: - error_502 = '502 Bad Gateway' + error_502 = "502 Bad Gateway" raise labelbox.exceptions.InternalServerError(error_502) elif 500 <= response.status_code < 600: error_500 = f"Internal server http error {response.status_code}" @@ -253,7 +298,7 @@ def convert_value(value): errors = r_json.get("errors", []) def check_errors(keywords, *path): - """ Helper that looks for any of the given `keywords` in any of + """Helper that looks for any of the given `keywords` in any of current errors on paths (like error[path][component][to][keyword]). 
""" for error in errors: @@ -270,18 +315,23 @@ def get_error_status_code(error: dict) -> int: except: return 500 - if check_errors(["AUTHENTICATION_ERROR"], "extensions", - "code") is not None: + if ( + check_errors(["AUTHENTICATION_ERROR"], "extensions", "code") + is not None + ): raise labelbox.exceptions.AuthenticationError("Invalid API key") - authorization_error = check_errors(["AUTHORIZATION_ERROR"], - "extensions", "code") + authorization_error = check_errors( + ["AUTHORIZATION_ERROR"], "extensions", "code" + ) if authorization_error is not None: raise labelbox.exceptions.AuthorizationError( - authorization_error["message"]) + authorization_error["message"] + ) - validation_error = check_errors(["GRAPHQL_VALIDATION_FAILED"], - "extensions", "code") + validation_error = check_errors( + ["GRAPHQL_VALIDATION_FAILED"], "extensions", "code" + ) if validation_error is not None: message = validation_error["message"] @@ -290,11 +340,13 @@ def get_error_status_code(error: dict) -> int: else: raise labelbox.exceptions.InvalidQueryError(message) - graphql_error = check_errors(["GRAPHQL_PARSE_FAILED"], "extensions", - "code") + graphql_error = check_errors( + ["GRAPHQL_PARSE_FAILED"], "extensions", "code" + ) if graphql_error is not None: raise labelbox.exceptions.InvalidQueryError( - graphql_error["message"]) + graphql_error["message"] + ) # Check if API limit was exceeded response_msg = r_json.get("message", "") @@ -302,34 +354,41 @@ def get_error_status_code(error: dict) -> int: if response_msg.startswith("You have exceeded"): raise labelbox.exceptions.ApiLimitError(response_msg) - resource_not_found_error = check_errors(["RESOURCE_NOT_FOUND"], - "extensions", "code") + resource_not_found_error = check_errors( + ["RESOURCE_NOT_FOUND"], "extensions", "code" + ) if resource_not_found_error is not None: if raise_return_resource_not_found: raise labelbox.exceptions.ResourceNotFoundError( - message=resource_not_found_error["message"]) + message=resource_not_found_error["message"] + ) else: # Return None and let the caller methods raise an exception # as they already know which resource type and ID was requested return None - resource_conflict_error = check_errors(["RESOURCE_CONFLICT"], - "extensions", "code") + resource_conflict_error = check_errors( + ["RESOURCE_CONFLICT"], "extensions", "code" + ) if resource_conflict_error is not None: raise labelbox.exceptions.ResourceConflict( - resource_conflict_error["message"]) + resource_conflict_error["message"] + ) - malformed_request_error = check_errors(["MALFORMED_REQUEST"], - "extensions", "code") + malformed_request_error = check_errors( + ["MALFORMED_REQUEST"], "extensions", "code" + ) if malformed_request_error is not None: raise labelbox.exceptions.MalformedQueryException( - malformed_request_error[error_log_key]) + malformed_request_error[error_log_key] + ) # A lot of different error situations are now labeled serverside # as INTERNAL_SERVER_ERROR, when they are actually client errors. 
# TODO: fix this in the server API - internal_server_error = check_errors(["INTERNAL_SERVER_ERROR"], - "extensions", "code") + internal_server_error = check_errors( + ["INTERNAL_SERVER_ERROR"], "extensions", "code" + ) if internal_server_error is not None: message = internal_server_error.get("message") error_status_code = get_error_status_code(internal_server_error) @@ -344,8 +403,9 @@ def get_error_status_code(error: dict) -> int: else: raise labelbox.exceptions.InternalServerError(message) - not_allowed_error = check_errors(["OPERATION_NOT_ALLOWED"], - "extensions", "code") + not_allowed_error = check_errors( + ["OPERATION_NOT_ALLOWED"], "extensions", "code" + ) if not_allowed_error is not None: message = not_allowed_error.get("message") raise labelbox.exceptions.OperationNotAllowedException(message) @@ -356,10 +416,14 @@ def get_error_status_code(error: dict) -> int: map( lambda x: { "message": x["message"], - "code": x["extensions"]["code"] - }, errors)) - raise labelbox.exceptions.LabelboxError("Unknown error: %s" % - str(messages)) + "code": x["extensions"]["code"], + }, + errors, + ) + ) + raise labelbox.exceptions.LabelboxError( + "Unknown error: %s" % str(messages) + ) # if we do return a proper error code, and didn't catch this above # reraise @@ -368,7 +432,7 @@ def get_error_status_code(error: dict) -> int: # in the SDK if response.status_code != requests.codes.ok: message = f"{response.status_code} {response.reason}" - cause = r_json.get('message') + cause = r_json.get("message") raise labelbox.exceptions.LabelboxError(message, cause) return r_json["data"] @@ -388,18 +452,23 @@ def upload_file(self, path: str) -> str: content_type, _ = mimetypes.guess_type(path) filename = os.path.basename(path) with open(path, "rb") as f: - return self.upload_data(content=f.read(), - filename=filename, - content_type=content_type) - - @retry.Retry(predicate=retry.if_exception_type( - labelbox.exceptions.InternalServerError)) - def upload_data(self, - content: bytes, - filename: str = None, - content_type: str = None, - sign: bool = False) -> str: - """ Uploads the given data (bytes) to Labelbox. + return self.upload_data( + content=f.read(), filename=filename, content_type=content_type + ) + + @retry.Retry( + predicate=retry.if_exception_type( + labelbox.exceptions.InternalServerError + ) + ) + def upload_data( + self, + content: bytes, + filename: str = None, + content_type: str = None, + sign: bool = False, + ) -> str: + """Uploads the given data (bytes) to Labelbox. 
Args: content: bytestring to upload @@ -415,40 +484,43 @@ def upload_data(self, """ request_data = { - "operations": - json.dumps({ + "operations": json.dumps( + { "variables": { "file": None, "contentLength": len(content), - "sign": sign + "sign": sign, }, - "query": - """mutation UploadFile($file: Upload!, $contentLength: Int!, + "query": """mutation UploadFile($file: Upload!, $contentLength: Int!, $sign: Boolean) { uploadFile(file: $file, contentLength: $contentLength, sign: $sign) {url filename} } """, - }), + } + ), "map": (None, json.dumps({"1": ["variables.file"]})), } files = { - "1": (filename, content, content_type) if - (filename and content_type) else content + "1": (filename, content, content_type) + if (filename and content_type) + else content } headers = self._connection.headers.copy() headers.pop("Content-Type", None) - request = requests.Request('POST', - self.endpoint, - headers=headers, - data=request_data, - files=files) + request = requests.Request( + "POST", + self.endpoint, + headers=headers, + data=request_data, + files=files, + ) prepped: requests.PreparedRequest = request.prepare() response = self._connection.send(prepped) if response.status_code == 502: - error_502 = '502 Bad Gateway' + error_502 = "502 Bad Gateway" raise labelbox.exceptions.InternalServerError(error_502) elif response.status_code == 503: raise labelbox.exceptions.InternalServerError(response.text) @@ -459,22 +531,25 @@ def upload_data(self, file_data = response.json().get("data", None) except ValueError as e: # response is not valid JSON raise labelbox.exceptions.LabelboxError( - "Failed to upload, unknown cause", e) + "Failed to upload, unknown cause", e + ) if not file_data or not file_data.get("uploadFile", None): try: errors = response.json().get("errors", []) - error_msg = next(iter(errors), {}).get("message", - "Unknown error") + error_msg = next(iter(errors), {}).get( + "message", "Unknown error" + ) except Exception as e: error_msg = "Unknown error" raise labelbox.exceptions.LabelboxError( - "Failed to upload, message: %s" % error_msg) + "Failed to upload, message: %s" % error_msg + ) return file_data["uploadFile"]["url"] def _get_single(self, db_object_type, uid): - """ Fetches a single object of the given type, for the given ID. + """Fetches a single object of the given type, for the given ID. Args: db_object_type (type): DbObject subclass. @@ -491,12 +566,13 @@ def _get_single(self, db_object_type, uid): res = res and res.get(utils.camel_case(db_object_type.type_name())) if res is None: raise labelbox.exceptions.ResourceNotFoundError( - db_object_type, params) + db_object_type, params + ) else: return db_object_type(self, res) def get_project(self, project_id) -> Project: - """ Gets a single Project with the given ID. + """Gets a single Project with the given ID. >>> project = client.get_project("") @@ -511,7 +587,7 @@ def get_project(self, project_id) -> Project: return self._get_single(Entity.Project, project_id) def get_dataset(self, dataset_id) -> Dataset: - """ Gets a single Dataset with the given ID. + """Gets a single Dataset with the given ID. >>> dataset = client.get_dataset("") @@ -526,21 +602,21 @@ def get_dataset(self, dataset_id) -> Dataset: return self._get_single(Entity.Dataset, dataset_id) def get_user(self) -> User: - """ Gets the current User database object. + """Gets the current User database object. 
>>> user = client.get_user() """ return self._get_single(Entity.User, None) def get_organization(self) -> Organization: - """ Gets the Organization DB object of the current user. + """Gets the Organization DB object of the current user. >>> organization = client.get_organization() """ return self._get_single(Entity.Organization, None) def _get_all(self, db_object_type, where, filter_deleted=True): - """ Fetches all the objects of the given type the user has access to. + """Fetches all the objects of the given type the user has access to. Args: db_object_type (type): DbObject subclass. @@ -555,12 +631,15 @@ def _get_all(self, db_object_type, where, filter_deleted=True): query_str, params = query.get_all(db_object_type, where) return PaginatedCollection( - self, query_str, params, + self, + query_str, + params, [utils.camel_case(db_object_type.type_name()) + "s"], - db_object_type) + db_object_type, + ) def get_projects(self, where=None) -> PaginatedCollection: - """ Fetches all the projects the user has access to. + """Fetches all the projects the user has access to. >>> projects = client.get_projects(where=(Project.name == "") & (Project.description == "")) @@ -573,7 +652,7 @@ def get_projects(self, where=None) -> PaginatedCollection: return self._get_all(Entity.Project, where) def get_users(self, where=None) -> PaginatedCollection: - """ Fetches all the users. + """Fetches all the users. >>> users = client.get_users(where=User.email == "") @@ -586,7 +665,7 @@ def get_users(self, where=None) -> PaginatedCollection: return self._get_all(Entity.User, where, filter_deleted=False) def get_datasets(self, where=None) -> PaginatedCollection: - """ Fetches one or more datasets. + """Fetches one or more datasets. >>> datasets = client.get_datasets(where=(Dataset.name == "") & (Dataset.description == "")) @@ -599,7 +678,7 @@ def get_datasets(self, where=None) -> PaginatedCollection: return self._get_all(Entity.Dataset, where) def get_labeling_frontends(self, where=None) -> List[LabelingFrontend]: - """ Fetches all the labeling frontends. + """Fetches all the labeling frontends. >>> frontend = client.get_labeling_frontends(where=LabelingFrontend.name == "Editor") @@ -612,7 +691,7 @@ def get_labeling_frontends(self, where=None) -> List[LabelingFrontend]: return self._get_all(Entity.LabelingFrontend, where) def _create(self, db_object_type, data, extra_params={}): - """ Creates an object on the server. Attribute values are + """Creates an object on the server. Attribute values are passed as keyword arguments: Args: @@ -630,8 +709,9 @@ def _create(self, db_object_type, data, extra_params={}): # Convert string attribute names to Field or Relationship objects. # Also convert Labelbox object values to their UIDs. 
data = { - db_object_type.attribute(attr) if isinstance(attr, str) else attr: - value.uid if isinstance(value, DbObject) else value + db_object_type.attribute(attr) + if isinstance(attr, str) + else attr: value.uid if isinstance(value, DbObject) else value for attr, value in data.items() } @@ -640,15 +720,17 @@ def _create(self, db_object_type, data, extra_params={}): res = self.execute(query_string, params) if not res: - raise labelbox.exceptions.LabelboxError("Failed to create %s" % - db_object_type.type_name()) + raise labelbox.exceptions.LabelboxError( + "Failed to create %s" % db_object_type.type_name() + ) res = res["create%s" % db_object_type.type_name()] return db_object_type(self, res) - def create_model_config(self, name: str, model_id: str, - inference_params: dict) -> ModelConfig: - """ Creates a new model config with the given params. + def create_model_config( + self, name: str, model_id: str, inference_params: dict + ) -> ModelConfig: + """Creates a new model config with the given params. Model configs are scoped to organizations, and can be reused between projects. Args: @@ -673,13 +755,13 @@ def create_model_config(self, name: str, model_id: str, params = { "modelId": model_id, "inferenceParams": inference_params, - "name": name + "name": name, } result = self.execute(query, params) - return ModelConfig(self, result['createModelConfig']) + return ModelConfig(self, result["createModelConfig"]) def delete_model_config(self, id: str) -> bool: - """ Deletes an existing model config with the given id + """Deletes an existing model config with the given id Args: id (str): ID of existing model config @@ -697,13 +779,14 @@ def delete_model_config(self, id: str) -> bool: result = self.execute(query, params) if not result: raise labelbox.exceptions.ResourceNotFoundError( - Entity.ModelConfig, params) - return result['deleteModelConfig']['success'] + Entity.ModelConfig, params + ) + return result["deleteModelConfig"]["success"] - def create_dataset(self, - iam_integration=IAMIntegration._DEFAULT, - **kwargs) -> Dataset: - """ Creates a Dataset object on the server. + def create_dataset( + self, iam_integration=IAMIntegration._DEFAULT, **kwargs + ) -> Dataset: + """Creates a Dataset object on the server. Attribute values are passed as keyword arguments. @@ -724,8 +807,9 @@ def create_dataset(self, """ dataset = self._create(Entity.Dataset, kwargs) if iam_integration == IAMIntegration._DEFAULT: - iam_integration = self.get_organization( - ).get_default_iam_integration() + iam_integration = ( + self.get_organization().get_default_iam_integration() + ) if iam_integration is None: return dataset @@ -738,21 +822,23 @@ def create_dataset(self, if not iam_integration.valid: raise ValueError( - "Integration is not valid. Please select another.") + "Integration is not valid. Please select another." + ) self.execute( """mutation setSignerForDatasetPyApi($signerId: ID!, $datasetId: ID!) 
{ setSignerForDataset(data: { signerId: $signerId}, where: {id: $datasetId}){id}} - """, { - 'signerId': iam_integration.uid, - 'datasetId': dataset.uid - }) + """, + {"signerId": iam_integration.uid, "datasetId": dataset.uid}, + ) validation_result = self.execute( """mutation validateDatasetPyApi($id: ID!){validateDataset(where: {id : $id}){ valid checks{name, success}}} - """, {'id': dataset.uid}) + """, + {"id": dataset.uid}, + ) - if not validation_result['validateDataset']['valid']: + if not validation_result["validateDataset"]["valid"]: raise labelbox.exceptions.LabelboxError( f"IAMIntegration was not successfully added to the dataset." ) @@ -762,7 +848,7 @@ def create_dataset(self, return dataset def create_project(self, **kwargs) -> Project: - """ Creates a Project object on the server. + """Creates a Project object on the server. Attribute values are passed as keyword arguments. @@ -800,26 +886,32 @@ def create_project(self, **kwargs) -> Project: return self._create_project(**kwargs) @overload - def create_model_evaluation_project(self, - dataset_name: str, - dataset_id: str = None, - data_row_count: int = 100, - **kwargs) -> Project: + def create_model_evaluation_project( + self, + dataset_name: str, + dataset_id: str = None, + data_row_count: int = 100, + **kwargs, + ) -> Project: pass @overload - def create_model_evaluation_project(self, - dataset_id: str, - dataset_name: str = None, - data_row_count: int = 100, - **kwargs) -> Project: + def create_model_evaluation_project( + self, + dataset_id: str, + dataset_name: str = None, + data_row_count: int = 100, + **kwargs, + ) -> Project: pass - def create_model_evaluation_project(self, - dataset_id: Optional[str] = None, - dataset_name: Optional[str] = None, - data_row_count: int = 100, - **kwargs) -> Project: + def create_model_evaluation_project( + self, + dataset_id: Optional[str] = None, + dataset_name: Optional[str] = None, + data_row_count: int = 100, + **kwargs, + ) -> Project: """ Use this method exclusively to create a chat model evaluation project. Args: @@ -875,10 +967,12 @@ def create_offline_model_evaluation_project(self, **kwargs) -> Project: Returns: Project: The created project """ - kwargs[ - "media_type"] = MediaType.Conversational # Only Conversational is supported - kwargs[ - "editor_task_type"] = EditorTaskType.OfflineModelChatEvaluation.value # Special editor task type for offline model evaluation + kwargs["media_type"] = ( + MediaType.Conversational + ) # Only Conversational is supported + kwargs["editor_task_type"] = ( + EditorTaskType.OfflineModelChatEvaluation.value + ) # Special editor task type for offline model evaluation # The following arguments are not supported for offline model evaluation kwargs.pop("dataset_name_or_id", None) @@ -888,11 +982,12 @@ def create_offline_model_evaluation_project(self, **kwargs) -> Project: return self._create_project(**kwargs) def create_prompt_response_generation_project( - self, - dataset_id: Optional[str] = None, - dataset_name: Optional[str] = None, - data_row_count: int = 100, - **kwargs) -> Project: + self, + dataset_id: Optional[str] = None, + dataset_name: Optional[str] = None, + data_row_count: int = 100, + **kwargs, + ) -> Project: """ Use this method exclusively to create a prompt and response generation project. @@ -927,7 +1022,8 @@ def create_prompt_response_generation_project( if dataset_id and dataset_name: raise ValueError( - "Only provide a dataset_name or dataset_id, not both.") + "Only provide a dataset_name or dataset_id, not both." 
+ ) if data_row_count <= 0: raise ValueError("data_row_count must be a positive integer.") @@ -940,7 +1036,8 @@ def create_prompt_response_generation_project( dataset_name_or_id = dataset_name if "media_type" in kwargs and kwargs.get("media_type") not in [ - MediaType.LLMPromptCreation, MediaType.LLMPromptResponseCreation + MediaType.LLMPromptCreation, + MediaType.LLMPromptResponseCreation, ]: raise ValueError( "media_type must be either LLMPromptCreation or LLMPromptResponseCreation" @@ -963,8 +1060,9 @@ def create_response_creation_project(self, **kwargs) -> Project: Project: The created project """ kwargs["media_type"] = MediaType.Text # Only Text is supported - kwargs[ - "editor_task_type"] = EditorTaskType.ResponseCreation.value # Special editor task type for response creation projects + kwargs["editor_task_type"] = ( + EditorTaskType.ResponseCreation.value + ) # Special editor task type for response creation projects # The following arguments are not supported for response creation projects kwargs.pop("dataset_name_or_id", None) @@ -976,7 +1074,10 @@ def create_response_creation_project(self, **kwargs) -> Project: def _create_project(self, **kwargs) -> Project: auto_audit_percentage = kwargs.get("auto_audit_percentage") auto_audit_number_of_labels = kwargs.get("auto_audit_number_of_labels") - if auto_audit_percentage is not None or auto_audit_number_of_labels is not None: + if ( + auto_audit_percentage is not None + or auto_audit_number_of_labels is not None + ): raise ValueError( "quality_modes must be set instead of auto_audit_percentage or auto_audit_number_of_labels." ) @@ -999,13 +1100,16 @@ def _create_project(self, **kwargs) -> Project: if media_type and MediaType.is_supported(media_type): media_type_value = media_type.value elif media_type: - raise TypeError(f"{media_type} is not a valid media type. Use" - f" any of {MediaType.get_supported_members()}" - " from MediaType. Example: MediaType.Image.") + raise TypeError( + f"{media_type} is not a valid media type. Use" + f" any of {MediaType.get_supported_members()}" + " from MediaType. Example: MediaType.Image." + ) else: logger.warning( "Creating a project without specifying media_type" - " through this method will soon no longer be supported.") + " through this method will soon no longer be supported." 
+ ) media_type_value = None quality_modes = kwargs.get("quality_modes") @@ -1034,22 +1138,28 @@ def _create_project(self, **kwargs) -> Project: if quality_mode: quality_modes_set = {quality_mode} - if (quality_modes_set is None or len(quality_modes_set) == 0 or - quality_modes_set - == {QualityMode.Benchmark, QualityMode.Consensus}): - data[ - "auto_audit_number_of_labels"] = CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS + if ( + quality_modes_set is None + or len(quality_modes_set) == 0 + or quality_modes_set + == {QualityMode.Benchmark, QualityMode.Consensus} + ): + data["auto_audit_number_of_labels"] = ( + CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS + ) data["auto_audit_percentage"] = CONSENSUS_AUTO_AUDIT_PERCENTAGE data["is_benchmark_enabled"] = True data["is_consensus_enabled"] = True elif quality_modes_set == {QualityMode.Benchmark}: - data[ - "auto_audit_number_of_labels"] = BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS + data["auto_audit_number_of_labels"] = ( + BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS + ) data["auto_audit_percentage"] = BENCHMARK_AUTO_AUDIT_PERCENTAGE data["is_benchmark_enabled"] = True elif quality_modes_set == {QualityMode.Consensus}: - data[ - "auto_audit_number_of_labels"] = CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS + data["auto_audit_number_of_labels"] = ( + CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS + ) data["auto_audit_percentage"] = CONSENSUS_AUTO_AUDIT_PERCENTAGE data["is_consensus_enabled"] = True else: @@ -1062,10 +1172,12 @@ def _create_project(self, **kwargs) -> Project: params["media_type"] = media_type_value extra_params = { - Field.String("dataset_name_or_id"): - params.pop("dataset_name_or_id", None), - Field.Boolean("append_to_existing_dataset"): - params.pop("append_to_existing_dataset", None), + Field.String("dataset_name_or_id"): params.pop( + "dataset_name_or_id", None + ), + Field.Boolean("append_to_existing_dataset"): params.pop( + "append_to_existing_dataset", None + ), } extra_params = {k: v for k, v in extra_params.items() if v is not None} return self._create(Entity.Project, params, extra_params) @@ -1089,13 +1201,14 @@ def get_data_row(self, data_row_id): def get_data_row_by_global_key(self, global_key: str) -> DataRow: """ - Returns: DataRow: returns a single data row given the global key + Returns: DataRow: returns a single data row given the global key """ res = self.get_data_row_ids_for_global_keys([global_key]) - if res['status'] != "SUCCESS": + if res["status"] != "SUCCESS": raise labelbox.exceptions.ResourceNotFoundError( - Entity.DataRow, {global_key: global_key}) - data_row_id = res['results'][0] + Entity.DataRow, {global_key: global_key} + ) + data_row_id = res["results"][0] return self.get_data_row(data_row_id) @@ -1111,7 +1224,7 @@ def get_data_row_metadata_ontology(self) -> DataRowMetadataOntology: return self._data_row_metadata_ontology def get_model(self, model_id) -> Model: - """ Gets a single Model with the given ID. + """Gets a single Model with the given ID. >>> model = client.get_model("") @@ -1126,7 +1239,7 @@ def get_model(self, model_id) -> Model: return self._get_single(Entity.Model, model_id) def get_models(self, where=None) -> List[Model]: - """ Fetches all the models the user has access to. + """Fetches all the models the user has access to. >>> models = client.get_models(where=(Model.name == "")) @@ -1139,7 +1252,7 @@ def get_models(self, where=None) -> List[Model]: return self._get_all(Entity.Model, where, filter_deleted=False) def create_model(self, name, ontology_id) -> Model: - """ Creates a Model object on the server. 
+ """Creates a Model object on the server. >>> model = client.create_model(, ) @@ -1158,14 +1271,14 @@ def create_model(self, name, ontology_id) -> Model: } }""" % query.results_query_part(Entity.Model) - result = self.execute(query_str, { - "name": name, - "ontologyId": ontology_id - }) - return Entity.Model(self, result['createModel']) + result = self.execute( + query_str, {"name": name, "ontologyId": ontology_id} + ) + return Entity.Model(self, result["createModel"]) def get_data_row_ids_for_external_ids( - self, external_ids: List[str]) -> Dict[str, List[str]]: + self, external_ids: List[str] + ) -> Dict[str, List[str]]: """ Returns a list of data row ids for a list of external ids. There is a max of 1500 items returned at a time. @@ -1183,10 +1296,10 @@ def get_data_row_ids_for_external_ids( result = defaultdict(list) for i in range(0, len(external_ids), max_ids_per_request): for row in self.execute( - query_str, - {'externalId_in': external_ids[i:i + max_ids_per_request] - })['externalIdsToDataRowIds']: - result[row['externalId']].append(row['dataRowId']) + query_str, + {"externalId_in": external_ids[i : i + max_ids_per_request]}, + )["externalIdsToDataRowIds"]: + result[row["externalId"]].append(row["dataRowId"]) return result def get_ontology(self, ontology_id) -> Ontology: @@ -1216,10 +1329,15 @@ def get_ontologies(self, name_contains) -> PaginatedCollection: } } """ % query.results_query_part(Entity.Ontology) - params = {'search': name_contains, 'filter': {'status': 'ALL'}} - return PaginatedCollection(self, query_str, params, - ['ontologies', 'nodes'], Entity.Ontology, - ['ontologies', 'nextCursor']) + params = {"search": name_contains, "filter": {"status": "ALL"}} + return PaginatedCollection( + self, + query_str, + params, + ["ontologies", "nodes"], + Entity.Ontology, + ["ontologies", "nextCursor"], + ) def get_feature_schema(self, feature_schema_id): """ @@ -1237,10 +1355,9 @@ def get_feature_schema(self, feature_schema_id): res = self.execute( query_str, - {'rootSchemaNodeWhere': { - 'featureSchemaId': feature_schema_id - }})['rootSchemaNode'] - res['id'] = res['normalized']['featureSchemaId'] + {"rootSchemaNodeWhere": {"featureSchemaId": feature_schema_id}}, + )["rootSchemaNode"] + res["id"] = res["normalized"]["featureSchemaId"] return Entity.FeatureSchema(self, res) def get_feature_schemas(self, name_contains) -> PaginatedCollection: @@ -1261,25 +1378,30 @@ def get_feature_schemas(self, name_contains) -> PaginatedCollection: } } """ % query.results_query_part(Entity.FeatureSchema) - params = {'search': name_contains, 'filter': {'status': 'ALL'}} + params = {"search": name_contains, "filter": {"status": "ALL"}} def rootSchemaPayloadToFeatureSchema(client, payload): # Technically we are querying for a Schema Node. 
# But the features are the same so we just grab the feature schema id - payload['id'] = payload['normalized']['featureSchemaId'] + payload["id"] = payload["normalized"]["featureSchemaId"] return Entity.FeatureSchema(client, payload) - return PaginatedCollection(self, query_str, params, - ['rootSchemaNodes', 'nodes'], - rootSchemaPayloadToFeatureSchema, - ['rootSchemaNodes', 'nextCursor']) + return PaginatedCollection( + self, + query_str, + params, + ["rootSchemaNodes", "nodes"], + rootSchemaPayloadToFeatureSchema, + ["rootSchemaNodes", "nextCursor"], + ) def create_ontology_from_feature_schemas( - self, - name, - feature_schema_ids, - media_type: MediaType = None, - ontology_kind: OntologyKind = None) -> Ontology: + self, + name, + feature_schema_ids, + media_type: MediaType = None, + ontology_kind: OntologyKind = None, + ) -> Ontology: """ Creates an ontology from a list of feature schema ids @@ -1298,22 +1420,27 @@ def create_ontology_from_feature_schemas( tools, classifications = [], [] for feature_schema_id in feature_schema_ids: feature_schema = self.get_feature_schema(feature_schema_id) - tool = ['tool'] - if 'tool' in feature_schema.normalized: - tool = feature_schema.normalized['tool'] + tool = ["tool"] + if "tool" in feature_schema.normalized: + tool = feature_schema.normalized["tool"] try: Tool.Type(tool) tools.append(feature_schema.normalized) except ValueError: raise ValueError( - f"Tool `{tool}` not in list of supported tools.") - elif 'type' in feature_schema.normalized: - classification = feature_schema.normalized['type'] - if classification in Classification.Type._value2member_map_.keys( + f"Tool `{tool}` not in list of supported tools." + ) + elif "type" in feature_schema.normalized: + classification = feature_schema.normalized["type"] + if ( + classification + in Classification.Type._value2member_map_.keys() ): Classification.Type(classification) classifications.append(feature_schema.normalized) - elif classification in PromptResponseClassification.Type._value2member_map_.keys( + elif ( + classification + in PromptResponseClassification.Type._value2member_map_.keys() ): PromptResponseClassification.Type(classification) classifications.append(feature_schema.normalized) @@ -1325,13 +1452,15 @@ def create_ontology_from_feature_schemas( raise ValueError( "Neither `tool` or `classification` found in the normalized feature schema" ) - normalized = {'tools': tools, 'classifications': classifications} + normalized = {"tools": tools, "classifications": classifications} # validation for ontology_kind and media_type is done within self.create_ontology - return self.create_ontology(name=name, - normalized=normalized, - media_type=media_type, - ontology_kind=ontology_kind) + return self.create_ontology( + name=name, + normalized=normalized, + media_type=media_type, + ontology_kind=ontology_kind, + ) def delete_unused_feature_schema(self, feature_schema_id: str) -> None: """ @@ -1342,14 +1471,18 @@ def delete_unused_feature_schema(self, feature_schema_id: str) -> None: >>> client.delete_unused_feature_schema("cleabc1my012ioqvu5anyaabc") """ - endpoint = self.rest_endpoint + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) + endpoint = ( + self.rest_endpoint + + "/feature-schemas/" + + urllib.parse.quote(feature_schema_id) + ) response = self._connection.delete(endpoint) if response.status_code != requests.codes.no_content: raise labelbox.exceptions.LabelboxError( - "Failed to delete the feature schema, message: " + - str(response.json()['message'])) + "Failed to delete 
the feature schema, message: " + + str(response.json()["message"]) + ) def delete_unused_ontology(self, ontology_id: str) -> None: """ @@ -1359,17 +1492,22 @@ def delete_unused_ontology(self, ontology_id: str) -> None: Example: >>> client.delete_unused_ontology("cleabc1my012ioqvu5anyaabc") """ - endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote( - ontology_id) + endpoint = ( + self.rest_endpoint + + "/ontologies/" + + urllib.parse.quote(ontology_id) + ) response = self._connection.delete(endpoint) if response.status_code != requests.codes.no_content: raise labelbox.exceptions.LabelboxError( - "Failed to delete the ontology, message: " + - str(response.json()['message'])) + "Failed to delete the ontology, message: " + + str(response.json()["message"]) + ) - def update_feature_schema_title(self, feature_schema_id: str, - title: str) -> FeatureSchema: + def update_feature_schema_title( + self, feature_schema_id: str, title: str + ) -> FeatureSchema: """ Updates a title of a feature schema Args: @@ -1381,16 +1519,21 @@ def update_feature_schema_title(self, feature_schema_id: str, >>> client.update_feature_schema_title("cleabc1my012ioqvu5anyaabc", "New Title") """ - endpoint = self.rest_endpoint + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) + '/definition' + endpoint = ( + self.rest_endpoint + + "/feature-schemas/" + + urllib.parse.quote(feature_schema_id) + + "/definition" + ) response = self._connection.patch(endpoint, json={"title": title}) if response.status_code == requests.codes.ok: return self.get_feature_schema(feature_schema_id) else: raise labelbox.exceptions.LabelboxError( - "Failed to update the feature schema, message: " + - str(response.json()['message'])) + "Failed to update the feature schema, message: " + + str(response.json()["message"]) + ) def upsert_feature_schema(self, feature_schema: Dict) -> FeatureSchema: """ @@ -1408,23 +1551,29 @@ def upsert_feature_schema(self, feature_schema: Dict) -> FeatureSchema: >>> client.upsert_feature_schema(tool.asdict()) """ - feature_schema_id = feature_schema.get( - "featureSchemaId") or "new_feature_schema_id" - endpoint = self.rest_endpoint + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) + feature_schema_id = ( + feature_schema.get("featureSchemaId") or "new_feature_schema_id" + ) + endpoint = ( + self.rest_endpoint + + "/feature-schemas/" + + urllib.parse.quote(feature_schema_id) + ) response = self._connection.put( - endpoint, json={"normalized": json.dumps(feature_schema)}) + endpoint, json={"normalized": json.dumps(feature_schema)} + ) if response.status_code == requests.codes.ok: - return self.get_feature_schema(response.json()['schemaId']) + return self.get_feature_schema(response.json()["schemaId"]) else: raise labelbox.exceptions.LabelboxError( - "Failed to upsert the feature schema, message: " + - str(response.json()['message'])) + "Failed to upsert the feature schema, message: " + + str(response.json()["message"]) + ) - def insert_feature_schema_into_ontology(self, feature_schema_id: str, - ontology_id: str, - position: int) -> None: + def insert_feature_schema_into_ontology( + self, feature_schema_id: str, ontology_id: str, position: int + ) -> None: """ Inserts a feature schema into an ontology. If the feature schema is already in the ontology, it will be moved to the new position. 
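For readers following the feature-schema helpers reformatted above (update_feature_schema_title, upsert_feature_schema, insert_feature_schema_into_ontology), a minimal usage sketch follows. It is illustrative only: the API key, schema ID, ontology ID, and title are placeholders, not values from this patch.

from labelbox import Client

client = Client(api_key="<YOUR_API_KEY>")  # placeholder key

feature_schema_id = "<feature_schema_id>"  # placeholder ID
ontology_id = "<ontology_id>"              # placeholder ID

# Rename the feature schema, then place it at position 2 in the ontology.
client.update_feature_schema_title(feature_schema_id, "New Title")
client.insert_feature_schema_into_ontology(feature_schema_id, ontology_id, 2)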
@@ -1436,14 +1585,19 @@ def insert_feature_schema_into_ontology(self, feature_schema_id: str, >>> client.insert_feature_schema_into_ontology("cleabc1my012ioqvu5anyaabc", "clefdvwl7abcgefgu3lyvcde", 2) """ - endpoint = self.rest_endpoint + '/ontologies/' + urllib.parse.quote( - ontology_id) + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) + endpoint = ( + self.rest_endpoint + + "/ontologies/" + + urllib.parse.quote(ontology_id) + + "/feature-schemas/" + + urllib.parse.quote(feature_schema_id) + ) response = self._connection.post(endpoint, json={"position": position}) if response.status_code != requests.codes.created: raise labelbox.exceptions.LabelboxError( "Failed to insert the feature schema into the ontology, message: " - + str(response.json()['message'])) + + str(response.json()["message"]) + ) def get_unused_ontologies(self, after: str = None) -> List[str]: """ @@ -1466,8 +1620,9 @@ def get_unused_ontologies(self, after: str = None) -> List[str]: return response.json() else: raise labelbox.exceptions.LabelboxError( - "Failed to get unused ontologies, message: " + - str(response.json()['message'])) + "Failed to get unused ontologies, message: " + + str(response.json()["message"]) + ) def get_unused_feature_schemas(self, after: str = None) -> List[str]: """ @@ -1490,14 +1645,17 @@ def get_unused_feature_schemas(self, after: str = None) -> List[str]: return response.json() else: raise labelbox.exceptions.LabelboxError( - "Failed to get unused feature schemas, message: " + - str(response.json()['message'])) + "Failed to get unused feature schemas, message: " + + str(response.json()["message"]) + ) - def create_ontology(self, - name, - normalized, - media_type: MediaType = None, - ontology_kind: OntologyKind = None) -> Ontology: + def create_ontology( + self, + name, + normalized, + media_type: MediaType = None, + ontology_kind: OntologyKind = None, + ) -> Ontology: """ Creates an ontology from normalized data >>> normalized = {"tools" : [{'tool': 'polygon', 'name': 'cat', 'color': 'black'}], "classifications" : []} @@ -1515,7 +1673,7 @@ def create_ontology(self, name (str): Name of the ontology normalized (dict): A normalized ontology payload. See above for details. media_type (MediaType or None): Media type of a new ontology - ontology_kind (OntologyKind or None): set to OntologyKind.ModelEvaluation if the ontology is for chat evaluation or + ontology_kind (OntologyKind or None): set to OntologyKind.ModelEvaluation if the ontology is for chat evaluation or OntologyKind.ResponseCreation if ontology is for response creation, leave as None otherwise. 
Returns: @@ -1533,9 +1691,11 @@ def create_ontology(self, if ontology_kind and OntologyKind.is_supported(ontology_kind): media_type = OntologyKind.evaluate_ontology_kind_with_media_type( - ontology_kind, media_type) + ontology_kind, media_type + ) editor_task_type_value = EditorTaskTypeMapper.to_editor_task_type( - ontology_kind, media_type).value + ontology_kind, media_type + ).value elif ontology_kind: raise OntologyKind.get_ontology_kind_validation_error(ontology_kind) else: @@ -1545,17 +1705,17 @@ def create_ontology(self, upsertOntology(data: $data){ %s } } """ % query.results_query_part(Entity.Ontology) params = { - 'data': { - 'name': name, - 'normalized': json.dumps(normalized), - 'mediaType': media_type_value + "data": { + "name": name, + "normalized": json.dumps(normalized), + "mediaType": media_type_value, } } if editor_task_type_value: - params['data']['editorTaskType'] = editor_task_type_value + params["data"]["editorTaskType"] = editor_task_type_value res = self.execute(query_str, params) - return Entity.Ontology(self, res['upsertOntology']) + return Entity.Ontology(self, res["upsertOntology"]) def create_feature_schema(self, normalized): """ @@ -1592,15 +1752,15 @@ def create_feature_schema(self, normalized): upsertRootSchemaNode(data: $data){ %s } } """ % query.results_query_part(Entity.FeatureSchema) normalized = {k: v for k, v in normalized.items() if v} - params = {'data': {'normalized': json.dumps(normalized)}} - res = self.execute(query_str, params)['upsertRootSchemaNode'] + params = {"data": {"normalized": json.dumps(normalized)}} + res = self.execute(query_str, params)["upsertRootSchemaNode"] # Technically we are querying for a Schema Node. # But the features are the same so we just grab the feature schema id - res['id'] = res['normalized']['featureSchemaId'] + res["id"] = res["normalized"]["featureSchemaId"] return Entity.FeatureSchema(self, res) def get_model_run(self, model_run_id: str) -> ModelRun: - """ Gets a single ModelRun with the given ID. + """Gets a single ModelRun with the given ID. >>> model_run = client.get_model_run("") @@ -1612,9 +1772,10 @@ def get_model_run(self, model_run_id: str) -> ModelRun: return self._get_single(Entity.ModelRun, model_run_id) def assign_global_keys_to_data_rows( - self, - global_key_to_data_row_inputs: List[Dict[str, str]], - timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]: + self, + global_key_to_data_row_inputs: List[Dict[str, str]], + timeout_seconds=60, + ) -> Dict[str, Union[str, List[Any]]]: """ Assigns global keys to data rows. 
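The ontology-creation helpers above (create_feature_schema, create_ontology, create_ontology_from_feature_schemas) compose naturally. A hedged sketch, reusing the normalized payload shape from the create_ontology docstring; the ontology name is invented and media_type is left at its default:

from labelbox import Client

client = Client(api_key="<YOUR_API_KEY>")  # placeholder key

# Create a standalone feature schema from a normalized tool payload.
normalized_tool = {"tool": "polygon", "name": "cat", "color": "black"}
feature_schema = client.create_feature_schema(normalized_tool)

# Reuse the new feature schema (by ID) in a fresh ontology.
ontology = client.create_ontology_from_feature_schemas(
    name="cats-ontology",  # placeholder name
    feature_schema_ids=[feature_schema.uid],  # uid assumed to hold the featureSchemaId
)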
@@ -1645,21 +1806,29 @@ def assign_global_keys_to_data_rows( [{'data_row_id': 'cl7tpjzw30031ka6g4evqdfoy', 'global_key': 'gk"', 'error': 'Invalid global key'}] """ - def _format_successful_rows(rows: Dict[str, str], - sanitized: bool) -> List[Dict[str, str]]: - return [{ - 'data_row_id': r['dataRowId'], - 'global_key': r['globalKey'], - 'sanitized': sanitized - } for r in rows] + def _format_successful_rows( + rows: Dict[str, str], sanitized: bool + ) -> List[Dict[str, str]]: + return [ + { + "data_row_id": r["dataRowId"], + "global_key": r["globalKey"], + "sanitized": sanitized, + } + for r in rows + ] - def _format_failed_rows(rows: Dict[str, str], - error_msg: str) -> List[Dict[str, str]]: - return [{ - 'data_row_id': r['dataRowId'], - 'global_key': r['globalKey'], - 'error': error_msg - } for r in rows] + def _format_failed_rows( + rows: Dict[str, str], error_msg: str + ) -> List[Dict[str, str]]: + return [ + { + "data_row_id": r["dataRowId"], + "global_key": r["globalKey"], + "error": error_msg, + } + for r in rows + ] # Validate input dict validation_errors = [] @@ -1679,9 +1848,10 @@ def _format_failed_rows(rows: Dict[str, str], } """ params = { - 'globalKeyDataRowLinks': [{ - utils.camel_case(key): value for key, value in input.items() - } for input in global_key_to_data_row_inputs] + "globalKeyDataRowLinks": [ + {utils.camel_case(key): value for key, value in input.items()} + for input in global_key_to_data_row_inputs + ] } assign_global_keys_to_data_rows_job = self.execute(query_str, params) @@ -1709,9 +1879,9 @@ def _format_failed_rows(rows: Dict[str, str], }}} """ result_params = { - "jobId": - assign_global_keys_to_data_rows_job["assignGlobalKeysToDataRows" - ]["jobId"] + "jobId": assign_global_keys_to_data_rows_job[ + "assignGlobalKeysToDataRows" + ]["jobId"] } # Poll job status until finished, then retrieve results @@ -1719,27 +1889,36 @@ def _format_failed_rows(rows: Dict[str, str], start_time = time.time() while True: res = self.execute(result_query_str, result_params) - if res["assignGlobalKeysToDataRowsResult"][ - "jobStatus"] == "COMPLETE": + if ( + res["assignGlobalKeysToDataRowsResult"]["jobStatus"] + == "COMPLETE" + ): results, errors = [], [] - res = res['assignGlobalKeysToDataRowsResult']['data'] + res = res["assignGlobalKeysToDataRowsResult"]["data"] # Successful assignments results.extend( - _format_successful_rows(rows=res['sanitizedAssignments'], - sanitized=True)) + _format_successful_rows( + rows=res["sanitizedAssignments"], sanitized=True + ) + ) results.extend( - _format_successful_rows(rows=res['unmodifiedAssignments'], - sanitized=False)) + _format_successful_rows( + rows=res["unmodifiedAssignments"], sanitized=False + ) + ) # Failed assignments errors.extend( _format_failed_rows( - rows=res['invalidGlobalKeyAssignments'], - error_msg= - "Invalid assignment. Either DataRow does not exist, or globalKey is invalid" - )) + rows=res["invalidGlobalKeyAssignments"], + error_msg="Invalid assignment. 
Either DataRow does not exist, or globalKey is invalid", + ) + ) errors.extend( - _format_failed_rows(rows=res['accessDeniedAssignments'], - error_msg="Access denied to Data Row")) + _format_failed_rows( + rows=res["accessDeniedAssignments"], + error_msg="Access denied to Data Row", + ) + ) if not errors: status = CollectionJobStatus.SUCCESS.value @@ -1758,10 +1937,12 @@ def _format_failed_rows(rows: Dict[str, str], "results": results, "errors": errors, } - elif res["assignGlobalKeysToDataRowsResult"][ - "jobStatus"] == "FAILED": + elif ( + res["assignGlobalKeysToDataRowsResult"]["jobStatus"] == "FAILED" + ): raise labelbox.exceptions.LabelboxError( - "Job assign_global_keys_to_data_rows failed.") + "Job assign_global_keys_to_data_rows failed." + ) current_time = time.time() if current_time - start_time > timeout_seconds: raise labelbox.exceptions.TimeoutError( @@ -1770,9 +1951,8 @@ def _format_failed_rows(rows: Dict[str, str], time.sleep(sleep_time) def get_data_row_ids_for_global_keys( - self, - global_keys: List[str], - timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]: + self, global_keys: List[str], timeout_seconds=60 + ) -> Dict[str, Union[str, List[Any]]]: """ Gets data row ids for a list of global keys. @@ -1805,9 +1985,10 @@ def get_data_row_ids_for_global_keys( [{'global_key': 'asdf', 'error': 'Data Row not found'}] """ - def _format_failed_rows(rows: List[str], - error_msg: str) -> List[Dict[str, str]]: - return [{'global_key': r, 'error': error_msg} for r in rows] + def _format_failed_rows( + rows: List[str], error_msg: str + ) -> List[Dict[str, str]]: + return [{"global_key": r, "error": error_msg} for r in rows] # Start get data rows for global keys job query_str = """query getDataRowsForGlobalKeysPyApi($globalKeys: [ID!]!) { @@ -1825,8 +2006,9 @@ def _format_failed_rows(rows: List[str], } jobStatus}} """ result_params = { - "jobId": - data_rows_for_global_keys_job["dataRowsForGlobalKeys"]["jobId"] + "jobId": data_rows_for_global_keys_job["dataRowsForGlobalKeys"][ + "jobId" + ] } # Poll job status until finished, then retrieve results @@ -1834,20 +2016,25 @@ def _format_failed_rows(rows: List[str], start_time = time.time() while True: res = self.execute(result_query_str, result_params) - if res["dataRowsForGlobalKeysResult"]['jobStatus'] == "COMPLETE": - data = res["dataRowsForGlobalKeysResult"]['data'] + if res["dataRowsForGlobalKeysResult"]["jobStatus"] == "COMPLETE": + data = res["dataRowsForGlobalKeysResult"]["data"] results, errors = [], [] - results.extend([row['id'] for row in data['fetchedDataRows']]) + results.extend([row["id"] for row in data["fetchedDataRows"]]) errors.extend( - _format_failed_rows(data['notFoundGlobalKeys'], - "Data Row not found")) + _format_failed_rows( + data["notFoundGlobalKeys"], "Data Row not found" + ) + ) errors.extend( - _format_failed_rows(data['accessDeniedGlobalKeys'], - "Access denied to Data Row")) + _format_failed_rows( + data["accessDeniedGlobalKeys"], + "Access denied to Data Row", + ) + ) # Invalid results may contain empty string, so we must filter # them prior to checking for PARTIAL_SUCCESS - filtered_results = list(filter(lambda r: r != '', results)) + filtered_results = list(filter(lambda r: r != "", results)) if not errors: status = CollectionJobStatus.SUCCESS.value elif errors and len(filtered_results) > 0: @@ -1861,9 +2048,10 @@ def _format_failed_rows(rows: List[str], ) return {"status": status, "results": results, "errors": errors} - elif res["dataRowsForGlobalKeysResult"]['jobStatus'] == "FAILED": + elif 
res["dataRowsForGlobalKeysResult"]["jobStatus"] == "FAILED": raise labelbox.exceptions.LabelboxError( - "Job dataRowsForGlobalKeys failed.") + "Job dataRowsForGlobalKeys failed." + ) current_time = time.time() if current_time - start_time > timeout_seconds: raise labelbox.exceptions.TimeoutError( @@ -1872,9 +2060,8 @@ def _format_failed_rows(rows: List[str], time.sleep(sleep_time) def clear_global_keys( - self, - global_keys: List[str], - timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]: + self, global_keys: List[str], timeout_seconds=60 + ) -> Dict[str, Union[str, List[Any]]]: """ Clears global keys for the data rows tha correspond to the global keys provided. @@ -1900,9 +2087,10 @@ def clear_global_keys( [{'global_key': 'notfoundkey', 'error': 'Failed to find data row matching provided global key'}] """ - def _format_failed_rows(rows: List[str], - error_msg: str) -> List[Dict[str, str]]: - return [{'global_key': r, 'error': error_msg} for r in rows] + def _format_failed_rows( + rows: List[str], error_msg: str + ) -> List[Dict[str, str]]: + return [{"global_key": r, "error": error_msg} for r in rows] # Start get data rows for global keys job query_str = """mutation clearGlobalKeysPyApi($globalKeys: [ID!]!) { @@ -1928,22 +2116,28 @@ def _format_failed_rows(rows: List[str], start_time = time.time() while True: res = self.execute(result_query_str, result_params) - if res["clearGlobalKeysResult"]['jobStatus'] == "COMPLETE": - data = res["clearGlobalKeysResult"]['data'] + if res["clearGlobalKeysResult"]["jobStatus"] == "COMPLETE": + data = res["clearGlobalKeysResult"]["data"] results, errors = [], [] - results.extend(data['clearedGlobalKeys']) + results.extend(data["clearedGlobalKeys"]) errors.extend( - _format_failed_rows(data['failedToClearGlobalKeys'], - "Clearing global key failed")) + _format_failed_rows( + data["failedToClearGlobalKeys"], + "Clearing global key failed", + ) + ) errors.extend( _format_failed_rows( - data['notFoundGlobalKeys'], - "Failed to find data row matching provided global key")) + data["notFoundGlobalKeys"], + "Failed to find data row matching provided global key", + ) + ) errors.extend( _format_failed_rows( - data['accessDeniedGlobalKeys'], - "Denied access to modify data row matching provided global key" - )) + data["accessDeniedGlobalKeys"], + "Denied access to modify data row matching provided global key", + ) + ) if not errors: status = CollectionJobStatus.SUCCESS.value @@ -1958,13 +2152,15 @@ def _format_failed_rows(rows: List[str], ) return {"status": status, "results": results, "errors": errors} - elif res["clearGlobalKeysResult"]['jobStatus'] == "FAILED": + elif res["clearGlobalKeysResult"]["jobStatus"] == "FAILED": raise labelbox.exceptions.LabelboxError( - "Job clearGlobalKeys failed.") + "Job clearGlobalKeys failed." + ) current_time = time.time() if current_time - start_time > timeout_seconds: raise labelbox.exceptions.TimeoutError( - "Timed out waiting for clear_global_keys job to complete.") + "Timed out waiting for clear_global_keys job to complete." 
+ ) time.sleep(sleep_time) def get_catalog(self) -> Catalog: @@ -1990,11 +2186,12 @@ def get_catalog_slice(self, slice_id) -> CatalogSlice: } } """ - res = self.execute(query_str, {'id': slice_id}) - return Entity.CatalogSlice(self, res['getSavedQuery']) + res = self.execute(query_str, {"id": slice_id}) + return Entity.CatalogSlice(self, res["getSavedQuery"]) - def is_feature_schema_archived(self, ontology_id: str, - feature_schema_id: str) -> bool: + def is_feature_schema_archived( + self, ontology_id: str, feature_schema_id: str + ) -> bool: """ Returns true if a feature schema is archived in the specified ontology, returns false otherwise. @@ -2005,33 +2202,39 @@ def is_feature_schema_archived(self, ontology_id: str, bool """ - ontology_endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote( - ontology_id) + ontology_endpoint = ( + self.rest_endpoint + + "/ontologies/" + + urllib.parse.quote(ontology_id) + ) response = self._connection.get(ontology_endpoint) if response.status_code == requests.codes.ok: - feature_schema_nodes = response.json()['featureSchemaNodes'] - tools = feature_schema_nodes['tools'] - classifications = feature_schema_nodes['classifications'] - relationships = feature_schema_nodes['relationships'] + feature_schema_nodes = response.json()["featureSchemaNodes"] + tools = feature_schema_nodes["tools"] + classifications = feature_schema_nodes["classifications"] + relationships = feature_schema_nodes["relationships"] feature_schema_node_list = tools + classifications + relationships filtered_feature_schema_nodes = [ feature_schema_node for feature_schema_node in feature_schema_node_list - if feature_schema_node['featureSchemaId'] == feature_schema_id + if feature_schema_node["featureSchemaId"] == feature_schema_id ] if filtered_feature_schema_nodes: - return bool(filtered_feature_schema_nodes[0]['archived']) + return bool(filtered_feature_schema_nodes[0]["archived"]) else: raise labelbox.exceptions.LabelboxError( - "The specified feature schema was not in the ontology.") + "The specified feature schema was not in the ontology." + ) elif response.status_code == 404: raise labelbox.exceptions.ResourceNotFoundError( - Ontology, ontology_id) + Ontology, ontology_id + ) else: raise labelbox.exceptions.LabelboxError( - "Failed to get the feature schema archived status.") + "Failed to get the feature schema archived status." + ) def get_model_slice(self, slice_id) -> ModelSlice: """ @@ -2057,13 +2260,14 @@ def get_model_slice(self, slice_id) -> ModelSlice: res = self.execute(query_str, {"id": slice_id}) if res is None or res["getSavedQuery"] is None: raise labelbox.exceptions.ResourceNotFoundError( - ModelSlice, slice_id) + ModelSlice, slice_id + ) return Entity.ModelSlice(self, res["getSavedQuery"]) def delete_feature_schema_from_ontology( - self, ontology_id: str, - feature_schema_id: str) -> DeleteFeatureFromOntologyResult: + self, ontology_id: str, feature_schema_id: str + ) -> DeleteFeatureFromOntologyResult: """ Deletes or archives a feature schema from an ontology. If the feature schema is a root level node with associated labels, it will be archived. 
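The global-key helpers reformatted in the hunks above are easiest to read as a round trip: assign a key, resolve it back to a data row ID, then clear it. A minimal sketch with a placeholder data row ID and a made-up key:

from labelbox import Client

client = Client(api_key="<YOUR_API_KEY>")  # placeholder key

# Attach a global key to an existing data row.
res = client.assign_global_keys_to_data_rows(
    [{"data_row_id": "<data_row_id>", "global_key": "image-0001"}]
)
assert res["status"] == "SUCCESS", res["errors"]

# Resolve the key back to its data row ID.
lookup = client.get_data_row_ids_for_global_keys(["image-0001"])
data_row_id = lookup["results"][0]

# Remove the key once it is no longer needed.
client.clear_global_keys(["image-0001"])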
@@ -2080,31 +2284,38 @@ def delete_feature_schema_from_ontology( Example: >>> client.delete_feature_schema_from_ontology(, ) """ - ontology_endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote( - ontology_id) + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) + ontology_endpoint = ( + self.rest_endpoint + + "/ontologies/" + + urllib.parse.quote(ontology_id) + + "/feature-schemas/" + + urllib.parse.quote(feature_schema_id) + ) response = self._connection.delete(ontology_endpoint) if response.status_code == requests.codes.ok: response_json = response.json() - if response_json['archived'] == True: + if response_json["archived"] == True: logger.info( - 'Feature schema was archived from the ontology because it had associated labels.' + "Feature schema was archived from the ontology because it had associated labels." ) - elif response_json['deleted'] == True: + elif response_json["deleted"] == True: logger.info( - 'Feature schema was successfully removed from the ontology') + "Feature schema was successfully removed from the ontology" + ) result = DeleteFeatureFromOntologyResult() - result.archived = bool(response_json['archived']) - result.deleted = bool(response_json['deleted']) + result.archived = bool(response_json["archived"]) + result.deleted = bool(response_json["deleted"]) return result else: raise labelbox.exceptions.LabelboxError( - "Failed to remove feature schema from ontology, message: " + - str(response.json()['message'])) + "Failed to remove feature schema from ontology, message: " + + str(response.json()["message"]) + ) - def unarchive_feature_schema_node(self, ontology_id: str, - root_feature_schema_id: str) -> None: + def unarchive_feature_schema_node( + self, ontology_id: str, root_feature_schema_id: str + ) -> None: """ Unarchives a feature schema node in an ontology. Only root level feature schema nodes can be unarchived. @@ -2114,18 +2325,25 @@ def unarchive_feature_schema_node(self, ontology_id: str, Returns: None """ - ontology_endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote( - ontology_id) + '/feature-schemas/' + urllib.parse.quote( - root_feature_schema_id) + '/unarchive' + ontology_endpoint = ( + self.rest_endpoint + + "/ontologies/" + + urllib.parse.quote(ontology_id) + + "/feature-schemas/" + + urllib.parse.quote(root_feature_schema_id) + + "/unarchive" + ) response = self._connection.patch(ontology_endpoint) if response.status_code == requests.codes.ok: - if not bool(response.json()['unarchived']): + if not bool(response.json()["unarchived"]): raise labelbox.exceptions.LabelboxError( - "Failed unarchive the feature schema.") + "Failed unarchive the feature schema." 
+ ) else: raise labelbox.exceptions.LabelboxError( "Failed unarchive the feature schema node, message: ", - response.text) + response.text, + ) def get_batch(self, project_id: str, batch_id: str) -> Entity.Batch: # obtain batch entity to return @@ -2138,24 +2356,28 @@ def get_batch(self, project_id: str, batch_id: str) -> Entity.Batch: } } } - """ % ("getProjectBatchPyApi", - query.results_query_part(Entity.Batch)) + """ % ( + "getProjectBatchPyApi", + query.results_query_part(Entity.Batch), + ) batch = self.execute( - get_batch_str, { - "projectId": project_id, - "batchId": batch_id - }, + get_batch_str, + {"projectId": project_id, "batchId": batch_id}, timeout=180.0, - experimental=True)["project"]["batches"]["nodes"][0] + experimental=True, + )["project"]["batches"]["nodes"][0] return Entity.Batch(self, project_id, batch) - def send_to_annotate_from_catalog(self, destination_project_id: str, - task_queue_id: Optional[str], - batch_name: str, - data_rows: Union[DataRowIds, GlobalKeys], - params: Dict[str, Any]): + def send_to_annotate_from_catalog( + self, + destination_project_id: str, + task_queue_id: Optional[str], + batch_name: str, + data_rows: Union[DataRowIds, GlobalKeys], + params: Dict[str, Any], + ): """ Sends data rows from catalog to a specified project for annotation. @@ -2196,56 +2418,55 @@ def send_to_annotate_from_catalog(self, destination_project_id: str, """ destination_task_queue = build_destination_task_queue_input( - task_queue_id) + task_queue_id + ) data_rows_query = self.build_catalog_query(data_rows) - predictions_input = build_predictions_input( - validated_params.predictions_ontology_mapping, - validated_params.source_model_run_id - ) if validated_params.source_model_run_id else None - - annotations_input = build_annotations_input( - validated_params.annotations_ontology_mapping, validated_params. 
- source_project_id) if validated_params.source_project_id else None + predictions_input = ( + build_predictions_input( + validated_params.predictions_ontology_mapping, + validated_params.source_model_run_id, + ) + if validated_params.source_model_run_id + else None + ) + + annotations_input = ( + build_annotations_input( + validated_params.annotations_ontology_mapping, + validated_params.source_project_id, + ) + if validated_params.source_project_id + else None + ) res = self.execute( - mutation_str, { + mutation_str, + { "input": { - "destinationProjectId": - destination_project_id, + "destinationProjectId": destination_project_id, "batchInput": { "batchName": batch_name, - "batchPriority": validated_params.batch_priority - }, - "destinationTaskQueue": - destination_task_queue, - "excludeDataRowsInProject": - validated_params.exclude_data_rows_in_project, - "annotationsInput": - annotations_input, - "predictionsInput": - predictions_input, - "conflictLabelsResolutionStrategy": - validated_params.override_existing_annotations_rule, - "searchQuery": { - "scope": None, - "query": [data_rows_query] + "batchPriority": validated_params.batch_priority, }, + "destinationTaskQueue": destination_task_queue, + "excludeDataRowsInProject": validated_params.exclude_data_rows_in_project, + "annotationsInput": annotations_input, + "predictionsInput": predictions_input, + "conflictLabelsResolutionStrategy": validated_params.override_existing_annotations_rule, + "searchQuery": {"scope": None, "query": [data_rows_query]}, "ordering": { "type": "RANDOM", - "random": { - "seed": random.randint(0, 10000) - }, - "sorting": None + "random": {"seed": random.randint(0, 10000)}, + "sorting": None, }, - "sorting": - None, - "limit": - None + "sorting": None, + "limit": None, } - })['sendToAnnotateFromCatalog'] + }, + )["sendToAnnotateFromCatalog"] - return Entity.Task.get_task(self, res['taskId']) + return Entity.Task.get_task(self, res["taskId"]) @staticmethod def build_catalog_query(data_rows: Union[DataRowIds, GlobalKeys]): @@ -2262,13 +2483,13 @@ def build_catalog_query(data_rows: Union[DataRowIds, GlobalKeys]): data_rows_query = { "type": "data_row_id", "operator": "is", - "ids": list(data_rows) + "ids": list(data_rows), } elif isinstance(data_rows, GlobalKeys): data_rows_query = { "type": "global_key", "operator": "is", - "ids": list(data_rows) + "ids": list(data_rows), } else: raise ValueError( @@ -2276,9 +2497,12 @@ def build_catalog_query(data_rows: Union[DataRowIds, GlobalKeys]): ) return data_rows_query - def run_foundry_app(self, model_run_name: str, data_rows: Union[DataRowIds, - GlobalKeys], - app_id: str) -> Task: + def run_foundry_app( + self, + model_run_name: str, + data_rows: Union[DataRowIds, GlobalKeys], + app_id: str, + ) -> Task: """ Run a foundry app @@ -2345,11 +2569,13 @@ def get_embedding_by_name(self, name: str) -> Embedding: for e in embeddings: if e.name == name: return e - raise labelbox.exceptions.ResourceNotFoundError(Embedding, - dict(name=name)) + raise labelbox.exceptions.ResourceNotFoundError( + Embedding, dict(name=name) + ) - def upsert_label_feedback(self, label_id: str, feedback: str, - scores: Dict[str, float]) -> List[LabelScore]: + def upsert_label_feedback( + self, label_id: str, feedback: str, scores: Dict[str, float] + ) -> List[LabelScore]: """ Submits the label feedback which is a free-form text and numeric label scores. 
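A hedged sketch of the send-to-annotate flow whose reformatted body appears above. The project IDs, batch name, and global keys are placeholders; the GlobalKeys import path and the params keys (mirroring the validated_params attributes used in the method) are assumptions, not guarantees:

from labelbox import Client
from labelbox.schema.identifiables import GlobalKeys  # assumed import path

client = Client(api_key="<YOUR_API_KEY>")  # placeholder key

task = client.send_to_annotate_from_catalog(
    destination_project_id="<destination_project_id>",
    task_queue_id=None,  # Optional[str]; a specific queue ID may be passed instead
    batch_name="batch-from-catalog",
    data_rows=GlobalKeys(["image-0001", "image-0002"]),
    params={
        # Assumed key names, matching the validated_params attributes above.
        "source_project_id": "<source_project_id>",
        "batch_priority": 5,
    },
)
task.wait_till_done()  # Task polling helper used elsewhere in the SDK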
@@ -2385,15 +2611,14 @@ def upsert_label_feedback(self, label_id: str, feedback: str, } } """ - res = self.execute(mutation_str, { - "labelId": label_id, - "feedback": feedback, - "scores": scores - }) + res = self.execute( + mutation_str, + {"labelId": label_id, "feedback": feedback, "scores": scores}, + ) scores_raw = res["upsertAutoQaLabelFeedback"]["scores"] return [ - labelbox.LabelScore(name=x['name'], score=x['score']) + labelbox.LabelScore(name=x["name"], score=x["score"]) for x in scores_raw ] @@ -2406,12 +2631,12 @@ def get_labeling_service_dashboards( Optional parameters: search_query: A list of search filters representing the search - + NOTE: - Retrieves all projects for the organization or as filtered by the search query - INCLUDING those not requesting labeling services - Sorted by project created date in ascending order. - + Examples: Retrieves all labeling service dashboards for a given workspace id: >>> workspace_filter = WorkspaceFilter( @@ -2442,7 +2667,7 @@ def get_task_by_id(self, task_id: str) -> Union[Task, DataUpsertTask]: Returns: Task or DataUpsertTask - + Throws: ResourceNotFoundError: If the task does not exist. @@ -2471,9 +2696,10 @@ def get_task_by_id(self, task_id: str) -> Union[Task, DataUpsertTask]: data = result.get("user", {}).get("createdTasks", []) if not data: raise labelbox.exceptions.ResourceNotFoundError( - message=f"The task {task_id} does not exist.") + message=f"The task {task_id} does not exist." + ) task_data = data[0] - if task_data["type"].lower() == 'adv-upsert-data-rows': + if task_data["type"].lower() == "adv-upsert-data-rows": task = DataUpsertTask(self, task_data) else: task = Task(self, task_data) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py index 5b51814ec..7908bc242 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py @@ -64,4 +64,11 @@ from .llm_prompt_response.prompt import PromptText from .llm_prompt_response.prompt import PromptClassificationAnnotation -from .mmc import MessageInfo, OrderedMessageInfo, MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask, MessageEvaluationTaskAnnotation +from .mmc import ( + MessageInfo, + OrderedMessageInfo, + MessageSingleSelectionTask, + MessageMultiSelectionTask, + MessageRankingTask, + MessageEvaluationTaskAnnotation, +) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/annotation.py b/libs/labelbox/src/labelbox/data/annotation_types/annotation.py index 8a718751a..2c2f110a0 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/annotation.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/annotation.py @@ -5,7 +5,9 @@ from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin -from labelbox.data.annotation_types.classification.classification import ClassificationAnnotation +from labelbox.data.annotation_types.classification.classification import ( + ClassificationAnnotation, +) from .ner import DocumentEntity, TextEntity, ConversationEntity from typing import Optional diff --git a/libs/labelbox/src/labelbox/data/annotation_types/base_annotation.py b/libs/labelbox/src/labelbox/data/annotation_types/base_annotation.py index 27e66c063..ee9bf751b 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/base_annotation.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/base_annotation.py @@ -7,11 +7,11 @@ class BaseAnnotation(FeatureSchema, abc.ABC): - 
""" Base annotation class. Shouldn't be directly instantiated - """ + """Base annotation class. Shouldn't be directly instantiated""" + _uuid: Optional[UUID] = PrivateAttr() extra: Dict[str, Any] = {} - + model_config = ConfigDict(extra="allow") def __init__(self, **data): diff --git a/libs/labelbox/src/labelbox/data/annotation_types/classification/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/classification/__init__.py index 5bb098730..a814336e4 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/classification/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/classification/__init__.py @@ -1,2 +1 @@ -from .classification import (Checklist, ClassificationAnswer, Radio, - Text) +from .classification import Checklist, ClassificationAnswer, Radio, Text diff --git a/libs/labelbox/src/labelbox/data/annotation_types/classification/classification.py b/libs/labelbox/src/labelbox/data/annotation_types/classification/classification.py index 23c4c848a..d6a6448dd 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/classification/classification.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/classification/classification.py @@ -18,40 +18,45 @@ class ClassificationAnswer(FeatureSchema, ConfidenceMixin, CustomMetricsMixin): So unlike object annotations, classification annotations track keyframes at a classification answer level. """ + extra: Dict[str, Any] = {} keyframe: Optional[bool] = None - classifications: Optional[List['ClassificationAnnotation']] = None + classifications: Optional[List["ClassificationAnnotation"]] = None class Radio(ConfidenceMixin, CustomMetricsMixin, BaseModel): - """ A classification with only one selected option allowed + """A classification with only one selected option allowed >>> Radio(answer = ClassificationAnswer(name = "dog")) """ + answer: ClassificationAnswer class Checklist(ConfidenceMixin, BaseModel): - """ A classification with many selected options allowed + """A classification with many selected options allowed >>> Checklist(answer = [ClassificationAnswer(name = "cloudy")]) """ + answer: List[ClassificationAnswer] class Text(ConfidenceMixin, CustomMetricsMixin, BaseModel): - """ Free form text + """Free form text >>> Text(answer = "some text answer") """ + answer: str -class ClassificationAnnotation(BaseAnnotation, ConfidenceMixin, - CustomMetricsMixin): +class ClassificationAnnotation( + BaseAnnotation, ConfidenceMixin, CustomMetricsMixin +): """Classification annotations (non localized) >>> ClassificationAnnotation( @@ -65,7 +70,7 @@ class ClassificationAnnotation(BaseAnnotation, ConfidenceMixin, feature_schema_id (Optional[Cuid]) value (Union[Text, Checklist, Radio]) extra (Dict[str, Any]) - """ + """ value: Union[Text, Checklist, Radio] message_id: Optional[str] = None diff --git a/libs/labelbox/src/labelbox/data/annotation_types/collection.py b/libs/labelbox/src/labelbox/data/annotation_types/collection.py index 04c78a583..d90204309 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/collection.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/collection.py @@ -17,7 +17,7 @@ class LabelGenerator(PrefetchGenerator): """ - A container for interacting with a large collection of labels. + A container for interacting with a large collection of labels. For a small number of labels, just use a list of Label objects. 
""" @@ -26,21 +26,23 @@ def __init__(self, data: Generator[Label, None, None], *args, **kwargs): super().__init__(data, *args, **kwargs) def assign_feature_schema_ids( - self, - ontology_builder: "ontology.OntologyBuilder") -> "LabelGenerator": - + self, ontology_builder: "ontology.OntologyBuilder" + ) -> "LabelGenerator": def _assign_ids(label: Label): label.assign_feature_schema_ids(ontology_builder) return label - warnings.warn("This method is deprecated and will be " - "removed in a future release. Feature schema ids" - " are no longer required for importing.") - self._fns['assign_feature_schema_ids'] = _assign_ids + warnings.warn( + "This method is deprecated and will be " + "removed in a future release. Feature schema ids" + " are no longer required for importing." + ) + self._fns["assign_feature_schema_ids"] = _assign_ids return self - def add_url_to_data(self, signer: Callable[[bytes], - str]) -> "LabelGenerator": + def add_url_to_data( + self, signer: Callable[[bytes], str] + ) -> "LabelGenerator": """ Creates signed urls for the data Only uploads url if one doesn't already exist. @@ -55,11 +57,12 @@ def _add_url_to_data(label: Label): label.add_url_to_data(signer) return label - self._fns['add_url_to_data'] = _add_url_to_data + self._fns["add_url_to_data"] = _add_url_to_data return self - def add_to_dataset(self, dataset: "Entity.Dataset", - signer: Callable[[bytes], str]) -> "LabelGenerator": + def add_to_dataset( + self, dataset: "Entity.Dataset", signer: Callable[[bytes], str] + ) -> "LabelGenerator": """ Creates data rows from each labels data object and attaches the data to the given dataset. Updates the label's data object to have the same external_id and uid as the data row. @@ -75,11 +78,12 @@ def _add_to_dataset(label: Label): label.create_data_row(dataset, signer) return label - self._fns['assign_datarow_ids'] = _add_to_dataset + self._fns["assign_datarow_ids"] = _add_to_dataset return self - def add_url_to_masks(self, signer: Callable[[bytes], - str]) -> "LabelGenerator": + def add_url_to_masks( + self, signer: Callable[[bytes], str] + ) -> "LabelGenerator": """ Creates signed urls for all masks in the LabelGenerator. Multiple masks can reference the same MaskData so this makes sure we only upload that url once. @@ -97,11 +101,12 @@ def _add_url_to_masks(label: Label): label.add_url_to_masks(signer) return label - self._fns['add_url_to_masks'] = _add_url_to_masks + self._fns["add_url_to_masks"] = _add_url_to_masks return self - def register_background_fn(self, fn: Callable[[Label], Label], - name: str) -> "LabelGenerator": + def register_background_fn( + self, fn: Callable[[Label], Label], name: str + ) -> "LabelGenerator": """ Allows users to add arbitrary io functions to the generator. These functions will be exectuted in parallel and added to a prefetch queue. 
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py index 99978caac..2522b2741 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py @@ -9,4 +9,4 @@ from .video import VideoData from .llm_prompt_response_creation import LlmPromptResponseCreationData from .llm_prompt_creation import LlmPromptCreationData -from .llm_response_creation import LlmResponseCreationData \ No newline at end of file +from .llm_response_creation import LlmResponseCreationData diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/audio.py b/libs/labelbox/src/labelbox/data/annotation_types/data/audio.py index 76be33110..916fca99d 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/audio.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/audio.py @@ -4,4 +4,4 @@ class AudioData(BaseData, _NoCoercionMixin): - class_name: Literal["AudioData"] = "AudioData" \ No newline at end of file + class_name: Literal["AudioData"] = "AudioData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/base_data.py b/libs/labelbox/src/labelbox/data/annotation_types/data/base_data.py index 2ccda34c3..7d26ba5ca 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/base_data.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/base_data.py @@ -9,6 +9,7 @@ class BaseData(BaseModel, ABC): Base class for objects representing data. This class shouldn't directly be used """ + external_id: Optional[str] = None uid: Optional[str] = None global_key: Optional[str] = None diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/dicom.py b/libs/labelbox/src/labelbox/data/annotation_types/data/dicom.py index 753475c3e..ae4c377dc 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/dicom.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/dicom.py @@ -4,4 +4,4 @@ class DicomData(BaseData, _NoCoercionMixin): - class_name: Literal["DicomData"] = "DicomData" \ No newline at end of file + class_name: Literal["DicomData"] = "DicomData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/document.py b/libs/labelbox/src/labelbox/data/annotation_types/data/document.py index 5b2610c5b..810a3ed3e 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/document.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/document.py @@ -4,4 +4,4 @@ class DocumentData(BaseData, _NoCoercionMixin): - class_name: Literal["DocumentData"] = "DocumentData" \ No newline at end of file + class_name: Literal["DocumentData"] = "DocumentData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/generic_data_row_data.py b/libs/labelbox/src/labelbox/data/annotation_types/data/generic_data_row_data.py index 6a73519c1..9bb6a7e0a 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/generic_data_row_data.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/generic_data_row_data.py @@ -6,8 +6,8 @@ class GenericDataRowData(BaseData, _NoCoercionMixin): - """Generic data row data. This is replacing all other DataType passed into Label - """ + """Generic data row data. 
This is replacing all other DataType passed into Label""" + url: Optional[str] = None class_name: Literal["GenericDataRowData"] = "GenericDataRowData" @@ -17,7 +17,7 @@ def create_url(self, signer: Callable[[bytes], str]) -> Optional[str]: @model_validator(mode="before") @classmethod def validate_one_datarow_key_present(cls, data): - keys = ['external_id', 'global_key', 'uid'] + keys = ["external_id", "global_key", "uid"] count = sum([key in data for key in keys]) if count < 1: diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/html.py b/libs/labelbox/src/labelbox/data/annotation_types/data/html.py index 1820ce467..7a78fcb7b 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/html.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/html.py @@ -4,4 +4,4 @@ class HTMLData(BaseData, _NoCoercionMixin): - class_name: Literal["HTMLData"] = "HTMLData" \ No newline at end of file + class_name: Literal["HTMLData"] = "HTMLData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_creation.py b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_creation.py index 4fd788f1a..a1b0450bc 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_creation.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_creation.py @@ -4,4 +4,4 @@ class LlmPromptCreationData(BaseData, _NoCoercionMixin): - class_name: Literal["LlmPromptCreationData"] = "LlmPromptCreationData" \ No newline at end of file + class_name: Literal["LlmPromptCreationData"] = "LlmPromptCreationData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_response_creation.py b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_response_creation.py index 2bad75f6d..a8dfce894 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_response_creation.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_response_creation.py @@ -4,5 +4,6 @@ class LlmPromptResponseCreationData(BaseData, _NoCoercionMixin): - class_name: Literal[ - "LlmPromptResponseCreationData"] = "LlmPromptResponseCreationData" \ No newline at end of file + class_name: Literal["LlmPromptResponseCreationData"] = ( + "LlmPromptResponseCreationData" + ) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_response_creation.py b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_response_creation.py index 43c604e34..a8963ed3f 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_response_creation.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_response_creation.py @@ -4,4 +4,4 @@ class LlmResponseCreationData(BaseData, _NoCoercionMixin): - class_name: Literal["LlmResponseCreationData"] = "LlmResponseCreationData" \ No newline at end of file + class_name: Literal["LlmResponseCreationData"] = "LlmResponseCreationData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py b/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py index 234d8b136..ba4c6485f 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py @@ -16,21 +16,23 @@ class RasterData(BaseModel, ABC): - """Represents an image or segmentation mask. 
- """ + """Represents an image or segmentation mask.""" + im_bytes: Optional[bytes] = None file_path: Optional[str] = None url: Optional[str] = None uid: Optional[str] = None global_key: Optional[str] = None - arr: Optional[TypedArray[Literal['uint8']]] = None - + arr: Optional[TypedArray[Literal["uint8"]]] = None + model_config = ConfigDict(extra="forbid") @classmethod - def from_2D_arr(cls, arr: Union[TypedArray[Literal['uint8']], - TypedArray[Literal['int']]], - **kwargs) -> "RasterData": + def from_2D_arr( + cls, + arr: Union[TypedArray[Literal["uint8"]], TypedArray[Literal["int"]]], + **kwargs, + ) -> "RasterData": """Construct from a 2D numpy array Args: @@ -117,11 +119,12 @@ def value(self) -> np.ndarray: raise ValueError("Must set either url, file_path or im_bytes") def set_fetch_fn(self, fn): - object.__setattr__(self, 'fetch_remote', lambda: fn(self)) + object.__setattr__(self, "fetch_remote", lambda: fn(self)) - @retry.Retry(deadline=15., - predicate=retry.if_exception_type(ConnectTimeout, - InternalServerError)) + @retry.Retry( + deadline=15.0, + predicate=retry.if_exception_type(ConnectTimeout, InternalServerError), + ) def fetch_remote(self) -> bytes: """ Method for accessing url. @@ -135,7 +138,7 @@ def fetch_remote(self) -> bytes: response.raise_for_status() return response.content - @retry.Retry(deadline=30.) + @retry.Retry(deadline=30.0) def create_url(self, signer: Callable[[bytes], str]) -> str: """ Utility for creating a url from any of the other image representations. @@ -150,13 +153,14 @@ def create_url(self, signer: Callable[[bytes], str]) -> str: elif self.im_bytes is not None: self.url = signer(self.im_bytes) elif self.file_path is not None: - with open(self.file_path, 'rb') as file: + with open(self.file_path, "rb") as file: self.url = signer(file.read()) elif self.arr is not None: self.url = signer(self.np_to_bytes(self.arr)) else: raise ValueError( - "One of url, im_bytes, file_path, arr must not be None.") + "One of url, im_bytes, file_path, arr must not be None." + ) return self.url @model_validator(mode="after") @@ -167,7 +171,10 @@ def validate_args(self, values): arr = self.arr uid = self.uid global_key = self.global_key - if uid == file_path == im_bytes == url == global_key == None and arr is None: + if ( + uid == file_path == im_bytes == url == global_key == None + and arr is None + ): raise ValueError( "One of `file_path`, `im_bytes`, `url`, `uid`, `global_key` or `arr` required." ) @@ -179,15 +186,18 @@ def validate_args(self, values): elif len(arr.shape) != 3: raise ValueError( "unsupported image format. Must be 3D ([H,W,C])." - f"Use {self.__name__}.from_2D_arr to construct from 2D") + f"Use {self.__name__}.from_2D_arr to construct from 2D" + ) return self def __repr__(self) -> str: - symbol_or_none = lambda data: '...' if data is not None else None - return f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)}," \ - f"file_path={self.file_path}," \ - f"url={self.url}," \ - f"arr={symbol_or_none(self.arr)})" + symbol_or_none = lambda data: "..." if data is not None else None + return ( + f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)}," + f"file_path={self.file_path}," + f"url={self.url}," + f"arr={symbol_or_none(self.arr)})" + ) class MaskData(RasterData): @@ -212,5 +222,4 @@ class MaskData(RasterData): """ -class ImageData(RasterData, BaseData): - ... +class ImageData(RasterData, BaseData): ... 
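As a concrete reference for the RasterData interface reformatted above, here is a small sketch of building a mask from a 2D array and signing it. It is a minimal illustration, not part of this change: it assumes numpy and the labelbox package are importable, fake_signer is a hypothetical stand-in for a real upload function, and the shape comment assumes from_2D_arr's usual behavior of expanding a 2D array to three channels.

import numpy as np

from labelbox.data.annotation_types.data.raster import MaskData


def fake_signer(content: bytes) -> str:
    # Hypothetical signer: a real implementation would upload `content`
    # and return the resulting URL.
    return "https://example.com/signed-mask.png"


# Build a simple binary-style mask as a 2D uint8 array.
flat_mask = np.zeros((64, 64), dtype=np.uint8)
flat_mask[16:48, 16:48] = 255

# from_2D_arr accepts a 2D uint8/int array and stores it on the object.
mask_data = MaskData.from_2D_arr(flat_mask)
print(mask_data.value.shape)  # expected (64, 64, 3) after channel expansion

# create_url() serializes whichever representation is set (arr here) and signs it.
print(mask_data.create_url(fake_signer))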
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/text.py b/libs/labelbox/src/labelbox/data/annotation_types/data/text.py index 20624c161..fe4c222d3 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/text.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/text.py @@ -22,6 +22,7 @@ class TextData(BaseData, _NoCoercionMixin): text (str) url (str) """ + class_name: Literal["TextData"] = "TextData" file_path: Optional[str] = None text: Optional[str] = None @@ -51,11 +52,12 @@ def value(self) -> str: raise ValueError("Must set either url, file_path or im_bytes") def set_fetch_fn(self, fn): - object.__setattr__(self, 'fetch_remote', lambda: fn(self)) + object.__setattr__(self, "fetch_remote", lambda: fn(self)) - @retry.Retry(deadline=15., - predicate=retry.if_exception_type(ConnectTimeout, - InternalServerError)) + @retry.Retry( + deadline=15.0, + predicate=retry.if_exception_type(ConnectTimeout, InternalServerError), + ) def fetch_remote(self) -> str: """ Method for accessing url. @@ -69,7 +71,7 @@ def fetch_remote(self) -> str: response.raise_for_status() return response.text - @retry.Retry(deadline=15.) + @retry.Retry(deadline=15.0) def create_url(self, signer: Callable[[bytes], str]) -> None: """ Utility for creating a url from any of the other text references. @@ -82,13 +84,14 @@ def create_url(self, signer: Callable[[bytes], str]) -> None: if self.url is not None: return self.url elif self.file_path is not None: - with open(self.file_path, 'rb') as file: + with open(self.file_path, "rb") as file: self.url = signer(file.read()) elif self.text is not None: self.url = signer(self.text.encode()) else: raise ValueError( - "One of url, im_bytes, file_path, numpy must not be None.") + "One of url, im_bytes, file_path, numpy must not be None." + ) return self.url @model_validator(mode="after") @@ -105,6 +108,8 @@ def validate_date(self, values): return self def __repr__(self) -> str: - return f"TextData(file_path={self.file_path}," \ - f"text={self.text[:30] + '...' if self.text is not None else None}," \ - f"url={self.url})" + return ( + f"TextData(file_path={self.file_path}," + f"text={self.text[:30] + '...' if self.text is not None else None}," + f"url={self.url})" + ) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/tiled_image.py b/libs/labelbox/src/labelbox/data/annotation_types/data/tiled_image.py index 5d3561ceb..adb8db549 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/tiled_image.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/tiled_image.py @@ -29,19 +29,20 @@ class EPSG(Enum): - """ Provides the EPSG for tiled image assets that are currently supported. - + """Provides the EPSG for tiled image assets that are currently supported. + SIMPLEPIXEL is Simple that can be used to obtain the pixel space coordinates >>> epsg = EPSG() """ + SIMPLEPIXEL = 1 EPSG4326 = 4326 EPSG3857 = 3857 class TiledBounds(BaseModel): - """ Bounds for a tiled image asset related to the relevant epsg. + """Bounds for a tiled image asset related to the relevant epsg. Bounds should be Point objects. 
@@ -51,21 +52,22 @@ class TiledBounds(BaseModel): Point(x=-99.20534818927473, y=19.400498983095076) ]) """ + epsg: EPSG bounds: List[Point] - @field_validator('bounds') + @field_validator("bounds") def validate_bounds_not_equal(cls, bounds): first_bound = bounds[0] second_bound = bounds[1] - if first_bound.x == second_bound.x or \ - first_bound.y == second_bound.y: + if first_bound.x == second_bound.x or first_bound.y == second_bound.y: raise ValueError( - f"Bounds on either axes cannot be equal, currently {bounds}") + f"Bounds on either axes cannot be equal, currently {bounds}" + ) return bounds - #validate bounds are within lat,lng range if they are EPSG4326 + # validate bounds are within lat,lng range if they are EPSG4326 @model_validator(mode="after") def validate_bounds_lat_lng(self): epsg = self.epsg @@ -74,16 +76,20 @@ def validate_bounds_lat_lng(self): if epsg == EPSG.EPSG4326: for bound in bounds: lat, lng = bound.y, bound.x - if int(lng) not in VALID_LNG_RANGE or int( - lat) not in VALID_LAT_RANGE: - raise ValueError(f"Invalid lat/lng bounds. Found {bounds}. " - f"lat must be in {VALID_LAT_RANGE}. " - f"lng must be in {VALID_LNG_RANGE}.") + if ( + int(lng) not in VALID_LNG_RANGE + or int(lat) not in VALID_LAT_RANGE + ): + raise ValueError( + f"Invalid lat/lng bounds. Found {bounds}. " + f"lat must be in {VALID_LAT_RANGE}. " + f"lng must be in {VALID_LNG_RANGE}." + ) return self class TileLayer(BaseModel): - """ Url that contains the tile layer. Must be in the format: + """Url that contains the tile layer. Must be in the format: https://c.tile.openstreetmap.org/{z}/{x}/{y}.png @@ -92,13 +98,14 @@ class TileLayer(BaseModel): name="slippy map tile" ) """ + url: str name: Optional[str] = "default" def asdict(self) -> Dict[str, str]: return {"tileLayerUrl": self.url, "name": self.name} - @field_validator('url') + @field_validator("url") def validate_url(cls, url): xyz_format = "/{z}/{x}/{y}" if xyz_format not in url: @@ -107,7 +114,7 @@ def validate_url(cls, url): class TiledImageData(BaseData): - """ Represents tiled imagery + """Represents tiled imagery If specified version is 2, converts bounds from [lng,lat] to [lat,lng] @@ -119,12 +126,13 @@ class TiledImageData(BaseData): max_native_zoom: int = None tile_size: Optional[int] version: int = 2 - alternative_layers: List[TileLayer] + alternative_layers: List[TileLayer] >>> tiled_image_data = TiledImageData(tile_layer=TileLayer, tile_bounds=TiledBounds, zoom_levels=[1, 12]) """ + tile_layer: TileLayer tile_bounds: TiledBounds alternative_layers: List[TileLayer] = [] @@ -141,9 +149,10 @@ def __post_init__(self) -> None: def asdict(self) -> Dict[str, str]: return { "tileLayerUrl": self.tile_layer.url, - "bounds": [[ - self.tile_bounds.bounds[0].x, self.tile_bounds.bounds[0].y - ], [self.tile_bounds.bounds[1].x, self.tile_bounds.bounds[1].y]], + "bounds": [ + [self.tile_bounds.bounds[0].x, self.tile_bounds.bounds[0].y], + [self.tile_bounds.bounds[1].x, self.tile_bounds.bounds[1].y], + ], "minZoom": self.zoom_levels[0], "maxZoom": self.zoom_levels[1], "maxNativeZoom": self.max_native_zoom, @@ -152,13 +161,12 @@ def asdict(self) -> Dict[str, str]: "alternativeLayers": [ layer.asdict() for layer in self.alternative_layers ], - "version": self.version + "version": self.version, } - def raster_data(self, - zoom: int = 0, - max_tiles: int = 32, - multithread=True) -> RasterData: + def raster_data( + self, zoom: int = 0, max_tiles: int = 32, multithread=True + ) -> RasterData: """Converts the tiled image asset into a RasterData object 
containing an np.ndarray. @@ -168,26 +176,33 @@ def raster_data(self, xstart, ystart, xend, yend = self._get_simple_image_params(zoom) elif self.tile_bounds.epsg == EPSG.EPSG4326: xstart, ystart, xend, yend = self._get_3857_image_params( - zoom, self.tile_bounds) + zoom, self.tile_bounds + ) elif self.tile_bounds.epsg == EPSG.EPSG3857: - #transform to 4326 + # transform to 4326 transformer = EPSGTransformer.create_geo_to_geo_transformer( - EPSG.EPSG3857, EPSG.EPSG4326) + EPSG.EPSG3857, EPSG.EPSG4326 + ) transforming_bounds = [ transformer(self.tile_bounds.bounds[0]), - transformer(self.tile_bounds.bounds[1]) + transformer(self.tile_bounds.bounds[1]), ] xstart, ystart, xend, yend = self._get_3857_image_params( - zoom, transforming_bounds) + zoom, transforming_bounds + ) else: raise ValueError(f"Unsupported epsg found: {self.tile_bounds.epsg}") self._validate_num_tiles(xstart, ystart, xend, yend, max_tiles) rounded_tiles, pixel_offsets = list( - zip(*[ - self._tile_to_pixel(pt) for pt in [xstart, ystart, xend, yend] - ])) + zip( + *[ + self._tile_to_pixel(pt) + for pt in [xstart, ystart, xend, yend] + ] + ) + ) image = self._fetch_image_for_bounds(*rounded_tiles, zoom, multithread) arr = self._crop_to_bounds(image, *pixel_offsets) @@ -195,13 +210,14 @@ def raster_data(self, @property def value(self) -> np.ndarray: - """Returns the value of a generated RasterData object. - """ - return self.raster_data(self.zoom_levels[0], - multithread=self.multithread).value - - def _get_simple_image_params(self, - zoom) -> Tuple[float, float, float, float]: + """Returns the value of a generated RasterData object.""" + return self.raster_data( + self.zoom_levels[0], multithread=self.multithread + ).value + + def _get_simple_image_params( + self, zoom + ) -> Tuple[float, float, float, float]: """Computes the x and y tile bounds for fetching an image that captures the entire labeling region (TiledData.bounds) given a specific zoom @@ -214,14 +230,16 @@ def _get_simple_image_params(self, self.tile_bounds.bounds[1].y, self.tile_bounds.bounds[0].y, ) - return (*[ - x * (2**(zoom)) / self.tile_size - for x in [xstart, ystart, xend, yend] - ],) + return ( + *[ + x * (2 ** (zoom)) / self.tile_size + for x in [xstart, ystart, xend, yend] + ], + ) def _get_3857_image_params( - self, zoom: int, - bounds: TiledBounds) -> Tuple[float, float, float, float]: + self, zoom: int, bounds: TiledBounds + ) -> Tuple[float, float, float, float]: """Computes the x and y tile bounds for fetching an image that captures the entire labeling region (TiledData.bounds) given a specific zoom """ @@ -237,10 +255,9 @@ def _get_3857_image_params( ystart, yend = min(ystart, yend), max(ystart, yend) return (*[pt * 2.0**zoom for pt in [xstart, ystart, xend, yend]],) - def _latlng_to_tile(self, - lat: float, - lng: float, - zoom=0) -> Tuple[float, float]: + def _latlng_to_tile( + self, lat: float, lng: float, zoom=0 + ) -> Tuple[float, float]: """Converts lat/lng to 3857 tile coordinates Formula found here: https://wiki.openstreetmap.org/wiki/Slippy_map_tilenames#lon.2Flat_to_tile_numbers_2 @@ -252,29 +269,31 @@ def _latlng_to_tile(self, return x, y def _tile_to_pixel(self, tile: float) -> Tuple[int, int]: - """Rounds a tile coordinate and reports the remainder in pixels - """ + """Rounds a tile coordinate and reports the remainder in pixels""" rounded_tile = int(tile) remainder = tile - rounded_tile pixel_offset = int(self.tile_size * remainder) return rounded_tile, pixel_offset - def _fetch_image_for_bounds(self, - x_tile_start: int, - 
y_tile_start: int, - x_tile_end: int, - y_tile_end: int, - zoom: int, - multithread=True) -> np.ndarray: - """Fetches the tiles and combines them into a single image. - + def _fetch_image_for_bounds( + self, + x_tile_start: int, + y_tile_start: int, + x_tile_end: int, + y_tile_end: int, + zoom: int, + multithread=True, + ) -> np.ndarray: + """Fetches the tiles and combines them into a single image. + If a tile cannot be fetched, a padding of expected tile size is instead added. """ if multithread: tiles = {} with ThreadPoolExecutor( - max_workers=TILE_DOWNLOAD_CONCURRENCY) as exc: + max_workers=TILE_DOWNLOAD_CONCURRENCY + ) as exc: for x in range(x_tile_start, x_tile_end + 1): for y in range(y_tile_start, y_tile_end + 1): tiles[(x, y)] = exc.submit(self._fetch_tile, x, y, zoom) @@ -290,8 +309,11 @@ def _fetch_image_for_bounds(self, row.append(self._fetch_tile(x, y, zoom)) except: row.append( - np.zeros(shape=(self.tile_size, self.tile_size, 3), - dtype=np.uint8)) + np.zeros( + shape=(self.tile_size, self.tile_size, 3), + dtype=np.uint8, + ) + ) rows.append(np.hstack(row)) return np.vstack(rows) @@ -331,19 +353,27 @@ def invert_point(pt): x_px_end, y_px_end = invert_point(x_px_end), invert_point(y_px_end) return image[y_px_start:y_px_end, x_px_start:x_px_end, :] - def _validate_num_tiles(self, xstart: float, ystart: float, xend: float, - yend: float, max_tiles: int): + def _validate_num_tiles( + self, + xstart: float, + ystart: float, + xend: float, + yend: float, + max_tiles: int, + ): """Calculates the number of expected tiles we would fetch. If this is greater than the number of max tiles, raise an error. """ total_n_tiles = (yend - ystart + 1) * (xend - xstart + 1) if total_n_tiles > max_tiles: - raise ValueError(f"Requested zoom results in {total_n_tiles} tiles." - f"Max allowed tiles are {max_tiles}" - f"Increase max tiles or reduce zoom level.") + raise ValueError( + f"Requested zoom results in {total_n_tiles} tiles." + f"Max allowed tiles are {max_tiles}" + f"Increase max tiles or reduce zoom level." + ) - @field_validator('zoom_levels') + @field_validator("zoom_levels") def validate_zoom_levels(cls, zoom_levels): if zoom_levels[0] > zoom_levels[1]: raise ValueError( @@ -356,8 +386,9 @@ class EPSGTransformer(BaseModel): """Transformer class between different EPSG's. Useful when wanting to project in different formats. """ + transformer: Any - model_config = ConfigDict(arbitrary_types_allowed = True) + model_config = ConfigDict(arbitrary_types_allowed=True) @staticmethod def _is_simple(epsg: EPSG) -> bool: @@ -366,7 +397,7 @@ def _is_simple(epsg: EPSG) -> bool: @staticmethod def _get_ranges(bounds: np.ndarray) -> Tuple[int, int]: """helper function to get the range between bounds. 
-
+
         returns a tuple (x_range, y_range)"""
         x_range = np.max(bounds[:, 0]) - np.min(bounds[:, 0])
         y_range = np.max(bounds[:, 1]) - np.min(bounds[:, 1])
@@ -374,90 +405,107 @@ def _get_ranges(bounds: np.ndarray) -> Tuple[int, int]:

     @staticmethod
     def _min_max_x_y(bounds: np.ndarray) -> Tuple[int, int, int, int]:
-        """returns the min x, max x, min y, max y of a numpy array
-        """
-        return np.min(bounds[:, 0]), np.max(bounds[:, 0]), np.min(
-            bounds[:, 1]), np.max(bounds[:, 1])
+        """returns the min x, max x, min y, max y of a numpy array"""
+        return (
+            np.min(bounds[:, 0]),
+            np.max(bounds[:, 0]),
+            np.min(bounds[:, 1]),
+            np.max(bounds[:, 1]),
+        )

     @classmethod
-    def geo_and_pixel(cls,
-                      src_epsg,
-                      pixel_bounds: TiledBounds,
-                      geo_bounds: TiledBounds,
-                      zoom=0) -> Callable:
+    def geo_and_pixel(
+        cls,
+        src_epsg,
+        pixel_bounds: TiledBounds,
+        geo_bounds: TiledBounds,
+        zoom=0,
+    ) -> Callable:
         """method to change from one projection to simple projection"""

         pixel_bounds = pixel_bounds.bounds
         geo_bounds_epsg = geo_bounds.epsg
         geo_bounds = geo_bounds.bounds

-        local_bounds = np.array([(point.x, point.y) for point in pixel_bounds],
-                                dtype=int)
-        #convert geo bounds to pixel bounds. assumes geo bounds are in wgs84/EPS4326 per leaflet
-        global_bounds = np.array([
-            PygeoPoint.from_latitude_longitude(latitude=point.y,
-                                               longitude=point.x).pixels(zoom)
-            for point in geo_bounds
-        ])
+        local_bounds = np.array(
+            [(point.x, point.y) for point in pixel_bounds], dtype=int
+        )
+        # convert geo bounds to pixel bounds. assumes geo bounds are in wgs84/EPSG4326 per leaflet
+        global_bounds = np.array(
+            [
+                PygeoPoint.from_latitude_longitude(
+                    latitude=point.y, longitude=point.x
+                ).pixels(zoom)
+                for point in geo_bounds
+            ]
+        )

-        #get the range of pixels for both sets of bounds to use as a multiplification factor
+        # get the range of pixels for both sets of bounds to use as a multiplication factor
         local_x_range, local_y_range = cls._get_ranges(bounds=local_bounds)
         global_x_range, global_y_range = cls._get_ranges(bounds=global_bounds)

         if src_epsg == EPSG.SIMPLEPIXEL:

             def transform(x: int, y: int) -> Callable[[int, int], Transformer]:
-                scaled_xy = (x * (global_x_range) / (local_x_range),
-                             y * (global_y_range) / (local_y_range))
+                scaled_xy = (
+                    x * (global_x_range) / (local_x_range),
+                    y * (global_y_range) / (local_y_range),
+                )

                 minx, _, miny, _ = cls._min_max_x_y(bounds=global_bounds)
                 x, y = map(lambda i, j: i + j, scaled_xy, (minx, miny))

-                point = PygeoPoint.from_pixel(pixel_x=x, pixel_y=y,
-                                              zoom=zoom).latitude_longitude
-                #convert to the desired epsg
-                return Transformer.from_crs(EPSG.EPSG4326.value,
-                                            geo_bounds_epsg.value,
-                                            always_xy=True).transform(
-                                                point[1], point[0])
+                point = PygeoPoint.from_pixel(
+                    pixel_x=x, pixel_y=y, zoom=zoom
+                ).latitude_longitude
+                # convert to the desired epsg
+                return Transformer.from_crs(
+                    EPSG.EPSG4326.value, geo_bounds_epsg.value, always_xy=True
+                ).transform(point[1], point[0])

             return transform

-        #handles 4326 from lat,lng
+        # handles 4326 from lat,lng
         elif src_epsg == EPSG.EPSG4326:

             def transform(x: int, y: int) -> Callable[[int, int], Transformer]:
                 point_in_px = PygeoPoint.from_latitude_longitude(
-                    latitude=y, longitude=x).pixels(zoom)
+                    latitude=y, longitude=x
+                ).pixels(zoom)

                 minx, _, miny, _ = cls._min_max_x_y(global_bounds)
                 x, y = map(lambda i, j: i - j, point_in_px, (minx, miny))

-                return (x * (local_x_range) / (global_x_range),
-                        y * (local_y_range) / (global_y_range))
+                return (
+                    x * (local_x_range) / (global_x_range),
+                    y * (local_y_range) /
(global_y_range), + ) return transform - #handles 3857 from meters + # handles 3857 from meters elif src_epsg == EPSG.EPSG3857: def transform(x: int, y: int) -> Callable[[int, int], Transformer]: - point_in_px = PygeoPoint.from_meters(meter_y=y, - meter_x=x).pixels(zoom) + point_in_px = PygeoPoint.from_meters( + meter_y=y, meter_x=x + ).pixels(zoom) minx, _, miny, _ = cls._min_max_x_y(global_bounds) x, y = map(lambda i, j: i - j, point_in_px, (minx, miny)) - return (x * (local_x_range) / (global_x_range), - y * (local_y_range) / (global_y_range)) + return ( + x * (local_x_range) / (global_x_range), + y * (local_y_range) / (global_y_range), + ) return transform @classmethod def create_geo_to_geo_transformer( - cls, src_epsg: EPSG, - tgt_epsg: EPSG) -> Callable[[int, int], Transformer]: - """method to change from one projection to another projection. + cls, src_epsg: EPSG, tgt_epsg: EPSG + ) -> Callable[[int, int], Transformer]: + """method to change from one projection to another projection. supports EPSG transformations not Simple. """ @@ -466,36 +514,45 @@ def create_geo_to_geo_transformer( f"Cannot be used for Simple transformations. Found {src_epsg} and {tgt_epsg}" ) - return EPSGTransformer(transformer=Transformer.from_crs( - src_epsg.value, tgt_epsg.value, always_xy=True).transform) + return EPSGTransformer( + transformer=Transformer.from_crs( + src_epsg.value, tgt_epsg.value, always_xy=True + ).transform + ) @classmethod def create_geo_to_pixel_transformer( - cls, - src_epsg, - pixel_bounds: TiledBounds, - geo_bounds: TiledBounds, - zoom=0) -> Callable[[int, int], Transformer]: + cls, + src_epsg, + pixel_bounds: TiledBounds, + geo_bounds: TiledBounds, + zoom=0, + ) -> Callable[[int, int], Transformer]: """method to change from a geo projection to Simple""" - transform_function = cls.geo_and_pixel(src_epsg=src_epsg, - pixel_bounds=pixel_bounds, - geo_bounds=geo_bounds, - zoom=zoom) + transform_function = cls.geo_and_pixel( + src_epsg=src_epsg, + pixel_bounds=pixel_bounds, + geo_bounds=geo_bounds, + zoom=zoom, + ) return EPSGTransformer(transformer=transform_function) @classmethod def create_pixel_to_geo_transformer( - cls, - src_epsg, - pixel_bounds: TiledBounds, - geo_bounds: TiledBounds, - zoom=0) -> Callable[[int, int], Transformer]: + cls, + src_epsg, + pixel_bounds: TiledBounds, + geo_bounds: TiledBounds, + zoom=0, + ) -> Callable[[int, int], Transformer]: """method to change from a geo projection to Simple""" - transform_function = cls.geo_and_pixel(src_epsg=src_epsg, - pixel_bounds=pixel_bounds, - geo_bounds=geo_bounds, - zoom=zoom) + transform_function = cls.geo_and_pixel( + src_epsg=src_epsg, + pixel_bounds=pixel_bounds, + geo_bounds=geo_bounds, + zoom=zoom, + ) return EPSGTransformer(transformer=transform_function) def _get_point_obj(self, point) -> Point: @@ -513,9 +570,12 @@ def __call__( return Line(points=[self._get_point_obj(p) for p in shape.points]) if isinstance(shape, Polygon): return Polygon( - points=[self._get_point_obj(p) for p in shape.points]) + points=[self._get_point_obj(p) for p in shape.points] + ) if isinstance(shape, Rectangle): - return Rectangle(start=self._get_point_obj(shape.start), - end=self._get_point_obj(shape.end)) + return Rectangle( + start=self._get_point_obj(shape.start), + end=self._get_point_obj(shape.end), + ) else: - raise ValueError(f"Unsupported type found: {type(shape)}") \ No newline at end of file + raise ValueError(f"Unsupported type found: {type(shape)}") diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/video.py 
b/libs/labelbox/src/labelbox/data/annotation_types/data/video.py index 5d7804860..581801036 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/video.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/video.py @@ -21,11 +21,12 @@ class VideoData(BaseData): """ Represents video """ + file_path: Optional[str] = None url: Optional[str] = None - frames: Optional[Dict[int, TypedArray[Literal['uint8']]]] = None + frames: Optional[Dict[int, TypedArray[Literal["uint8"]]]] = None # Required for discriminating between data types - model_config = ConfigDict(extra = "forbid") + model_config = ConfigDict(extra="forbid") def load_frames(self, overwrite: bool = False) -> None: """ @@ -48,9 +49,7 @@ def value(self): return self.frame_generator() def frame_generator( - self, - cache_frames=False, - download_dir='/tmp' + self, cache_frames=False, download_dir="/tmp" ) -> Generator[Tuple[int, np.ndarray], None, None]: """ A generator for accessing individual frames in a video. @@ -91,9 +90,9 @@ def __getitem__(self, idx: int) -> np.ndarray: return self.frames[idx] def set_fetch_fn(self, fn): - object.__setattr__(self, 'fetch_remote', lambda: fn(self)) + object.__setattr__(self, "fetch_remote", lambda: fn(self)) - @retry.Retry(deadline=15.) + @retry.Retry(deadline=15.0) def fetch_remote(self, local_path) -> None: """ Method for downloading data from self.url @@ -106,7 +105,7 @@ def fetch_remote(self, local_path) -> None: """ urllib.request.urlretrieve(self.url, local_path) - @retry.Retry(deadline=15.) + @retry.Retry(deadline=15.0) def create_url(self, signer: Callable[[bytes], str]) -> None: """ Utility for creating a url from any of the other video references. @@ -119,7 +118,7 @@ def create_url(self, signer: Callable[[bytes], str]) -> None: if self.url is not None: return self.url elif self.file_path is not None: - with open(self.file_path, 'rb') as file: + with open(self.file_path, "rb") as file: self.url = signer(file.read()) elif self.frames is not None: self.file_path = self.frames_to_video(self.frames) @@ -128,10 +127,9 @@ def create_url(self, signer: Callable[[bytes], str]) -> None: raise ValueError("One of url, file_path, frames must not be None.") return self.url - def frames_to_video(self, - frames: Dict[int, np.ndarray], - fps=20, - save_dir='/tmp') -> str: + def frames_to_video( + self, frames: Dict[int, np.ndarray], fps=20, save_dir="/tmp" + ) -> str: """ Compresses the data by converting a set of individual frames to a single video. @@ -141,9 +139,12 @@ def frames_to_video(self, for key in frames.keys(): frame = frames[key] if out is None: - out = cv2.VideoWriter(file_path, - cv2.VideoWriter_fourcc(*'MP4V'), fps, - frame.shape[:2]) + out = cv2.VideoWriter( + file_path, + cv2.VideoWriter_fourcc(*"MP4V"), + fps, + frame.shape[:2], + ) out.write(frame) if out is None: return @@ -165,6 +166,8 @@ def validate_data(self): return self def __repr__(self) -> str: - return f"VideoData(file_path={self.file_path}," \ - f"frames={'...' if self.frames is not None else None}," \ - f"url={self.url})" + return ( + f"VideoData(file_path={self.file_path}," + f"frames={'...' 
if self.frames is not None else None}," + f"url={self.url})" + ) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/feature.py b/libs/labelbox/src/labelbox/data/annotation_types/feature.py index 836817aeb..5b4591abc 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/feature.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/feature.py @@ -9,6 +9,7 @@ class FeatureSchema(BaseModel): Could be a annotation, a subclass, or an option. Schema ids might not be known when constructing these objects so both a name and schema id are valid. """ + name: Optional[str] = None feature_schema_id: Optional[Cuid] = None diff --git a/libs/labelbox/src/labelbox/data/annotation_types/geometry/geometry.py b/libs/labelbox/src/labelbox/data/annotation_types/geometry/geometry.py index acdfa94c2..7b5b42cd5 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/geometry/geometry.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/geometry/geometry.py @@ -9,23 +9,34 @@ class Geometry(BaseModel, ABC): - """Abstract base class for geometry objects - """ + """Abstract base class for geometry objects""" + extra: Dict[str, Any] = {} @property def shapely( - self - ) -> Union[geom.Point, geom.LineString, geom.Polygon, geom.MultiPoint, - geom.MultiLineString, geom.MultiPolygon]: + self, + ) -> Union[ + geom.Point, + geom.LineString, + geom.Polygon, + geom.MultiPoint, + geom.MultiLineString, + geom.MultiPolygon, + ]: return geom.shape(self.geometry) - def get_or_create_canvas(self, height: Optional[int], width: Optional[int], - canvas: Optional[np.ndarray]) -> np.ndarray: + def get_or_create_canvas( + self, + height: Optional[int], + width: Optional[int], + canvas: Optional[np.ndarray], + ) -> np.ndarray: if canvas is None: if height is None or width is None: raise ValueError( - "Must either provide canvas or height and width") + "Must either provide canvas or height and width" + ) canvas = np.zeros((height, width, 3), dtype=np.uint8) canvas = np.ascontiguousarray(canvas) return canvas @@ -36,10 +47,12 @@ def geometry(self) -> geojson: pass @abstractmethod - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Optional[Union[int, Tuple[int, int, int]]] = None, - thickness: Optional[int] = 1) -> np.ndarray: + def draw( + self, + height: Optional[int] = None, + width: Optional[int] = None, + canvas: Optional[np.ndarray] = None, + color: Optional[Union[int, Tuple[int, int, int]]] = None, + thickness: Optional[int] = 1, + ) -> np.ndarray: pass diff --git a/libs/labelbox/src/labelbox/data/annotation_types/geometry/line.py b/libs/labelbox/src/labelbox/data/annotation_types/geometry/line.py index fcd31b4e7..d8ea52f0c 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/geometry/line.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/geometry/line.py @@ -11,6 +11,7 @@ from pydantic import field_validator + class Line(Geometry): """Line annotation @@ -20,30 +21,36 @@ class Line(Geometry): >>> Line(points = [Point(x=3,y=4), Point(x=3,y=5)]) """ + points: List[Point] @property def geometry(self) -> geojson.MultiLineString: return geojson.MultiLineString( - [[[point.x, point.y] for point in self.points]]) + [[[point.x, point.y] for point in self.points]] + ) @classmethod def from_shapely(cls, shapely_obj: SLineString) -> "Line": """Transforms a shapely object.""" if not isinstance(shapely_obj, SLineString): raise TypeError( - f"Expected Shapely Line. Got {shapely_obj.geom_type}") + f"Expected Shapely Line. 
Got {shapely_obj.geom_type}" + ) - obj_coords = shapely_obj.__geo_interface__['coordinates'] + obj_coords = shapely_obj.__geo_interface__["coordinates"] return Line( - points=[Point(x=coords[0], y=coords[1]) for coords in obj_coords]) - - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Union[int, Tuple[int, int, int]] = (255, 255, 255), - thickness: int = 1) -> np.ndarray: + points=[Point(x=coords[0], y=coords[1]) for coords in obj_coords] + ) + + def draw( + self, + height: Optional[int] = None, + width: Optional[int] = None, + canvas: Optional[np.ndarray] = None, + color: Union[int, Tuple[int, int, int]] = (255, 255, 255), + thickness: int = 1, + ) -> np.ndarray: """ Draw the line onto a 3d mask Args: @@ -57,14 +64,12 @@ def draw(self, numpy array representing the mask with the line drawn on it. """ canvas = self.get_or_create_canvas(height, width, canvas) - pts = np.array(self.geometry['coordinates']).astype(np.int32) - return cv2.polylines(canvas, - pts, - False, - color=color, - thickness=thickness) - - @field_validator('points') + pts = np.array(self.geometry["coordinates"]).astype(np.int32) + return cv2.polylines( + canvas, pts, False, color=color, thickness=thickness + ) + + @field_validator("points") def is_geom_valid(cls, points): if len(points) < 2: raise ValueError( diff --git a/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py b/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py index 39051182f..0d870f24f 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py @@ -40,21 +40,22 @@ class Mask(Geometry): @property def geometry(self) -> Dict[str, Tuple[int, int, int]]: mask = self.draw(color=1) - contours, hierarchy = cv2.findContours(image=mask, - mode=cv2.RETR_TREE, - method=cv2.CHAIN_APPROX_NONE) + contours, hierarchy = cv2.findContours( + image=mask, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_NONE + ) holes = [] external_contours = [] for i in range(len(contours)): if hierarchy[0, i, 3] != -1: - #determined to be a hole based on contour hierarchy + # determined to be a hole based on contour hierarchy holes.append(contours[i]) else: external_contours.append(contours[i]) external_polygons = self._extract_polygons_from_contours( - external_contours) + external_contours + ) holes = self._extract_polygons_from_contours(holes) if not external_polygons.is_valid: @@ -65,12 +66,14 @@ def geometry(self) -> Dict[str, Tuple[int, int, int]]: return external_polygons.difference(holes).__geo_interface__ - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Optional[Union[int, Tuple[int, int, int]]] = None, - thickness=None) -> np.ndarray: + def draw( + self, + height: Optional[int] = None, + width: Optional[int] = None, + canvas: Optional[np.ndarray] = None, + color: Optional[Union[int, Tuple[int, int, int]]] = None, + thickness=None, + ) -> np.ndarray: """Converts the Mask object into a numpy array Args: @@ -91,16 +94,20 @@ def draw(self, mask = np.alltrue(mask == self.color, axis=2).astype(np.uint8) if height is not None or width is not None: - mask = cv2.resize(mask, - (width or mask.shape[1], height or mask.shape[0])) + mask = cv2.resize( + mask, (width or mask.shape[1], height or mask.shape[0]) + ) dims = [mask.shape[0], mask.shape[1]] color = color or self.color if isinstance(color, (tuple, list)): dims = dims + 
[len(color)] - canvas = canvas if canvas is not None else np.zeros(tuple(dims), - dtype=np.uint8) + canvas = ( + canvas + if canvas is not None + else np.zeros(tuple(dims), dtype=np.uint8) + ) canvas[mask.astype(bool)] = color return canvas @@ -122,7 +129,7 @@ def create_url(self, signer: Callable[[bytes], str]) -> str: """ return self.mask.create_url(signer) - @field_validator('color') + @field_validator("color") def is_valid_color(cls, color): if isinstance(color, (tuple, list)): if len(color) == 1: @@ -137,6 +144,7 @@ def is_valid_color(cls, color): ) elif not (0 <= color <= 255): raise ValueError( - f"All rgb colors must be between 0 and 255. Found : {color}") + f"All rgb colors must be between 0 and 255. Found : {color}" + ) return color diff --git a/libs/labelbox/src/labelbox/data/annotation_types/geometry/point.py b/libs/labelbox/src/labelbox/data/annotation_types/geometry/point.py index c3f736e76..c801628f9 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/geometry/point.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/geometry/point.py @@ -18,6 +18,7 @@ class Point(Geometry): y (float) """ + x: float y: float @@ -30,17 +31,20 @@ def from_shapely(cls, shapely_obj: SPoint) -> "Point": """Transforms a shapely object.""" if not isinstance(shapely_obj, SPoint): raise TypeError( - f"Expected Shapely Point. Got {shapely_obj.geom_type}") + f"Expected Shapely Point. Got {shapely_obj.geom_type}" + ) - obj_coords = shapely_obj.__geo_interface__['coordinates'] + obj_coords = shapely_obj.__geo_interface__["coordinates"] return Point(x=obj_coords[0], y=obj_coords[1]) - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Union[int, Tuple[int, int, int]] = (255, 255, 255), - thickness: int = 10) -> np.ndarray: + def draw( + self, + height: Optional[int] = None, + width: Optional[int] = None, + canvas: Optional[np.ndarray] = None, + color: Union[int, Tuple[int, int, int]] = (255, 255, 255), + thickness: int = 10, + ) -> np.ndarray: """ Draw the point onto a 3d mask Args: @@ -54,7 +58,10 @@ def draw(self, numpy array representing the mask with the point drawn on it. """ canvas = self.get_or_create_canvas(height, width, canvas) - return cv2.circle(canvas, (int(self.x), int(self.y)), - radius=thickness, - color=color, - thickness=-1) + return cv2.circle( + canvas, + (int(self.x), int(self.y)), + radius=thickness, + color=color, + thickness=-1, + ) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/geometry/polygon.py b/libs/labelbox/src/labelbox/data/annotation_types/geometry/polygon.py index 96e1f0c94..9785e7ab4 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/geometry/polygon.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/geometry/polygon.py @@ -25,6 +25,7 @@ class Polygon(Geometry): point is added to close it. """ + points: List[Point] @property @@ -36,20 +37,24 @@ def geometry(self) -> geojson.Polygon: @classmethod def from_shapely(cls, shapely_obj: SPolygon) -> "Polygon": """Transforms a shapely object.""" - #we only consider 0th index because we only allow for filled polygons + # we only consider 0th index because we only allow for filled polygons if not isinstance(shapely_obj, SPolygon): raise TypeError( - f"Expected Shapely Polygon. Got {shapely_obj.geom_type}") - obj_coords = shapely_obj.__geo_interface__['coordinates'][0] + f"Expected Shapely Polygon. 
Got {shapely_obj.geom_type}" + ) + obj_coords = shapely_obj.__geo_interface__["coordinates"][0] return Polygon( - points=[Point(x=coords[0], y=coords[1]) for coords in obj_coords]) - - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Union[int, Tuple[int, int, int]] = (255, 255, 255), - thickness: int = -1) -> np.ndarray: + points=[Point(x=coords[0], y=coords[1]) for coords in obj_coords] + ) + + def draw( + self, + height: Optional[int] = None, + width: Optional[int] = None, + canvas: Optional[np.ndarray] = None, + color: Union[int, Tuple[int, int, int]] = (255, 255, 255), + thickness: int = -1, + ) -> np.ndarray: """ Draw the polygon onto a 3d mask Args: @@ -63,12 +68,12 @@ def draw(self, numpy array representing the mask with the polygon drawn on it. """ canvas = self.get_or_create_canvas(height, width, canvas) - pts = np.array(self.geometry['coordinates']).astype(np.int32) + pts = np.array(self.geometry["coordinates"]).astype(np.int32) if thickness == -1: return cv2.fillPoly(canvas, pts, color) return cv2.polylines(canvas, pts, True, color, thickness) - @field_validator('points') + @field_validator("points") def is_geom_valid(cls, points): if len(points) < 3: raise ValueError( diff --git a/libs/labelbox/src/labelbox/data/annotation_types/geometry/rectangle.py b/libs/labelbox/src/labelbox/data/annotation_types/geometry/rectangle.py index 3c43d44ba..5cabf0957 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/geometry/rectangle.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/geometry/rectangle.py @@ -20,43 +20,52 @@ class Rectangle(Geometry): start (Point): Top left coordinate of the rectangle end (Point): Bottom right coordinate of the rectangle """ + start: Point end: Point @property def geometry(self) -> geojson.geometry.Geometry: - return geojson.Polygon([[ - [self.start.x, self.start.y], - [self.start.x, self.end.y], - [self.end.x, self.end.y], - [self.end.x, self.start.y], - [self.start.x, self.start.y], - ]]) + return geojson.Polygon( + [ + [ + [self.start.x, self.start.y], + [self.start.x, self.end.y], + [self.end.x, self.end.y], + [self.end.x, self.start.y], + [self.start.x, self.start.y], + ] + ] + ) @classmethod def from_shapely(cls, shapely_obj: SPolygon) -> "Rectangle": """Transforms a shapely object. - + If the provided shape is a non-rectangular polygon, a rectangle will be returned based on the min and max x,y values.""" if not isinstance(shapely_obj, SPolygon): raise TypeError( - f"Expected Shapely Polygon. Got {shapely_obj.geom_type}") + f"Expected Shapely Polygon. Got {shapely_obj.geom_type}" + ) min_x, min_y, max_x, max_y = shapely_obj.bounds start = [min_x, min_y] end = [max_x, max_y] - return Rectangle(start=Point(x=start[0], y=start[1]), - end=Point(x=end[0], y=end[1])) - - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Union[int, Tuple[int, int, int]] = (255, 255, 255), - thickness: int = -1) -> np.ndarray: + return Rectangle( + start=Point(x=start[0], y=start[1]), end=Point(x=end[0], y=end[1]) + ) + + def draw( + self, + height: Optional[int] = None, + width: Optional[int] = None, + canvas: Optional[np.ndarray] = None, + color: Union[int, Tuple[int, int, int]] = (255, 255, 255), + thickness: int = -1, + ) -> np.ndarray: """ Draw the rectangle onto a 3d mask Args: @@ -70,7 +79,7 @@ def draw(self, numpy array representing the mask with the rectangle drawn on it. 
""" canvas = self.get_or_create_canvas(height, width, canvas) - pts = np.array(self.geometry['coordinates']).astype(np.int32) + pts = np.array(self.geometry["coordinates"]).astype(np.int32) if thickness == -1: return cv2.fillPoly(canvas, pts, color) return cv2.polylines(canvas, pts, True, color, thickness) @@ -82,9 +91,9 @@ def from_xyhw(cls, x: float, y: float, h: float, w: float) -> "Rectangle": class RectangleUnit(Enum): - INCHES = 'INCHES' - PIXELS = 'PIXELS' - POINTS = 'POINTS' + INCHES = "INCHES" + PIXELS = "PIXELS" + POINTS = "POINTS" class DocumentRectangle(Rectangle): @@ -103,5 +112,6 @@ class DocumentRectangle(Rectangle): page (int): Page number of the document unit (RectangleUnits): Units of the rectangle """ + page: int unit: RectangleUnit diff --git a/libs/labelbox/src/labelbox/data/annotation_types/label.py b/libs/labelbox/src/labelbox/data/annotation_types/label.py index 973e9260f..c21a0ef8c 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/label.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/label.py @@ -3,14 +3,28 @@ import warnings import labelbox -from labelbox.data.annotation_types.data.generic_data_row_data import GenericDataRowData +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) from labelbox.data.annotation_types.data.tiled_image import TiledImageData from labelbox.schema import ontology from .annotation import ClassificationAnnotation, ObjectAnnotation from .relationship import RelationshipAnnotation from .llm_prompt_response.prompt import PromptClassificationAnnotation from .classification import ClassificationAnswer -from .data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, TextData, VideoData, LlmPromptCreationData, LlmPromptResponseCreationData, LlmResponseCreationData +from .data import ( + AudioData, + ConversationData, + DicomData, + DocumentData, + HTMLData, + ImageData, + TextData, + VideoData, + LlmPromptCreationData, + LlmPromptResponseCreationData, + LlmResponseCreationData, +) from .geometry import Mask from .metrics import ScalarMetric, ConfusionMatrixMetric from .types import Cuid @@ -20,10 +34,21 @@ from ..ontology import get_feature_schema_lookup from pydantic import BaseModel, field_validator, model_serializer -DataType = Union[VideoData, ImageData, TextData, TiledImageData, AudioData, - ConversationData, DicomData, DocumentData, HTMLData, - LlmPromptCreationData, LlmPromptResponseCreationData, - LlmResponseCreationData, GenericDataRowData] +DataType = Union[ + VideoData, + ImageData, + TextData, + TiledImageData, + AudioData, + ConversationData, + DicomData, + DocumentData, + HTMLData, + LlmPromptCreationData, + LlmPromptResponseCreationData, + LlmResponseCreationData, + GenericDataRowData, +] class Label(BaseModel): @@ -41,17 +66,26 @@ class Label(BaseModel): Args: uid: Optional Label Id in Labelbox - data: Data of Label, Image, Video, Text or dict with a single key uid | global_key | external_id. + data: Data of Label, Image, Video, Text or dict with a single key uid | global_key | external_id. Note use of classes as data is deprecated. Use GenericDataRowData or dict with a single key instead. 
annotations: List of Annotations in the label extra: additional context """ + uid: Optional[Cuid] = None data: DataType - annotations: List[Union[ClassificationAnnotation, ObjectAnnotation, - VideoMaskAnnotation, ScalarMetric, - ConfusionMatrixMetric, RelationshipAnnotation, - PromptClassificationAnnotation, MessageEvaluationTaskAnnotation]] = [] + annotations: List[ + Union[ + ClassificationAnnotation, + ObjectAnnotation, + VideoMaskAnnotation, + ScalarMetric, + ConfusionMatrixMetric, + RelationshipAnnotation, + PromptClassificationAnnotation, + MessageEvaluationTaskAnnotation, + ] + ] = [] extra: Dict[str, Any] = {} is_benchmark_reference: Optional[bool] = False @@ -64,7 +98,8 @@ def validate_data(cls, data): else: warnings.warn( f"Using {type(data).__name__} class for label.data is deprecated. " - "Use a dict or an instance of GenericDataRowData instead.") + "Use a dict or an instance of GenericDataRowData instead." + ) return data def object_annotations(self) -> List[ObjectAnnotation]: @@ -75,18 +110,20 @@ def classification_annotations(self) -> List[ClassificationAnnotation]: def _get_annotations_by_type(self, annotation_type): return [ - annot for annot in self.annotations + annot + for annot in self.annotations if isinstance(annot, annotation_type) ] def frame_annotations( - self + self, ) -> Dict[str, Union[VideoObjectAnnotation, VideoClassificationAnnotation]]: frame_dict = defaultdict(list) for annotation in self.annotations: if isinstance( - annotation, - (VideoObjectAnnotation, VideoClassificationAnnotation)): + annotation, + (VideoObjectAnnotation, VideoClassificationAnnotation), + ): frame_dict[annotation.frame].append(annotation) return frame_dict @@ -128,8 +165,9 @@ def add_url_to_masks(self, signer) -> "Label": mask.create_url(signer) return self - def create_data_row(self, dataset: "labelbox.Dataset", - signer: Callable[[bytes], str]) -> "Label": + def create_data_row( + self, dataset: "labelbox.Dataset", signer: Callable[[bytes], str] + ) -> "Label": """ Creates a data row and adds to the given dataset. Updates the label's data object to have the same external_id and uid as the data row. @@ -140,9 +178,9 @@ def create_data_row(self, dataset: "labelbox.Dataset", Returns: Label with updated references to new data row """ - args = {'row_data': self.data.create_url(signer)} + args = {"row_data": self.data.create_url(signer)} if self.data.external_id is not None: - args.update({'external_id': self.data.external_id}) + args.update({"external_id": self.data.external_id}) if self.data.uid is None: data_row = dataset.create_data_row(**args) @@ -151,7 +189,8 @@ def create_data_row(self, dataset: "labelbox.Dataset", return self def assign_feature_schema_ids( - self, ontology_builder: ontology.OntologyBuilder) -> "Label": + self, ontology_builder: ontology.OntologyBuilder + ) -> "Label": """ Adds schema ids to all FeatureSchema objects in the Labels. @@ -162,11 +201,14 @@ def assign_feature_schema_ids( Note: You can now import annotations using names directly without having to lookup schema_ids """ - warnings.warn("This method is deprecated and will be " - "removed in a future release. Feature schema ids" - " are no longer required for importing.") + warnings.warn( + "This method is deprecated and will be " + "removed in a future release. Feature schema ids" + " are no longer required for importing." 
+ ) tool_lookup, classification_lookup = get_feature_schema_lookup( - ontology_builder) + ontology_builder + ) for annotation in self.annotations: if isinstance(annotation, ClassificationAnnotation): self._assign_or_raise(annotation, classification_lookup) @@ -178,7 +220,8 @@ def assign_feature_schema_ids( self._assign_option(classification, classification_lookup) else: raise TypeError( - f"Unexpected type found for annotation. {type(annotation)}") + f"Unexpected type found for annotation. {type(annotation)}" + ) return self def _assign_or_raise(self, annotation, lookup: Dict[str, str]) -> None: @@ -187,12 +230,15 @@ def _assign_or_raise(self, annotation, lookup: Dict[str, str]) -> None: feature_schema_id = lookup.get(annotation.name) if feature_schema_id is None: - raise ValueError(f"No tool matches name {annotation.name}. " - f"Must be one of {list(lookup.keys())}.") + raise ValueError( + f"No tool matches name {annotation.name}. " + f"Must be one of {list(lookup.keys())}." + ) annotation.feature_schema_id = feature_schema_id - def _assign_option(self, classification: ClassificationAnnotation, - lookup: Dict[str, str]) -> None: + def _assign_option( + self, classification: ClassificationAnnotation, lookup: Dict[str, str] + ) -> None: if isinstance(classification.value.answer, str): pass elif isinstance(classification.value.answer, ClassificationAnswer): @@ -207,10 +253,14 @@ def _assign_option(self, classification: ClassificationAnnotation, @field_validator("annotations", mode="before") def validate_union(cls, value): - supported = tuple([ - field - for field in get_args(get_args(cls.model_fields['annotations'].annotation)[0]) - ]) + supported = tuple( + [ + field + for field in get_args( + get_args(cls.model_fields["annotations"].annotation)[0] + ) + ] + ) if not isinstance(value, list): raise TypeError(f"Annotations must be a list. 
Found {type(value)}") prompt_count = 0 @@ -224,5 +274,6 @@ def validate_union(cls, value): prompt_count += 1 if prompt_count > 1: raise TypeError( - f"Only one prompt annotation is allowed per label") + f"Only one prompt annotation is allowed per label" + ) return value diff --git a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py index 7c0b63abc..4f4c0ee0e 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/__init__.py @@ -1,2 +1,2 @@ from .prompt import PromptText -from .prompt import PromptClassificationAnnotation \ No newline at end of file +from .prompt import PromptClassificationAnnotation diff --git a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py index 98c0e7a69..b5a7e4fe5 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/llm_prompt_response/prompt.py @@ -4,7 +4,7 @@ class PromptText(ConfidenceMixin, CustomMetricsMixin, BaseModel): - """ Prompt text for LLM data generation + """Prompt text for LLM data generation >>> PromptText(answer = "some text answer", >>> confidence = 0.5, @@ -14,11 +14,13 @@ class PromptText(ConfidenceMixin, CustomMetricsMixin, BaseModel): >>> "value": 0.1 >>> }]) """ + answer: str -class PromptClassificationAnnotation(BaseAnnotation, ConfidenceMixin, - CustomMetricsMixin): +class PromptClassificationAnnotation( + BaseAnnotation, ConfidenceMixin, CustomMetricsMixin +): """Prompt annotation (non localized) >>> PromptClassificationAnnotation( @@ -30,6 +32,6 @@ class PromptClassificationAnnotation(BaseAnnotation, ConfidenceMixin, name (Optional[str]) feature_schema_id (Optional[Cuid]) value (Union[Text]) - """ + """ - value: PromptText \ No newline at end of file + value: PromptText diff --git a/libs/labelbox/src/labelbox/data/annotation_types/metrics/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/metrics/__init__.py index 2c7e45178..37750dd1f 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/metrics/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/metrics/__init__.py @@ -1,2 +1,6 @@ from .scalar import ScalarMetric, ScalarMetricAggregation, ScalarMetricValue -from .confusion_matrix import ConfusionMatrixMetric, ConfusionMatrixAggregation, ConfusionMatrixMetricValue +from .confusion_matrix import ( + ConfusionMatrixMetric, + ConfusionMatrixAggregation, + ConfusionMatrixMetricValue, +) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/metrics/base.py b/libs/labelbox/src/labelbox/data/annotation_types/metrics/base.py index 7c0636f48..0a4773a41 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/metrics/base.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/metrics/base.py @@ -1,7 +1,13 @@ from abc import ABC from typing import Dict, Optional, Any, Union -from pydantic import confloat, BaseModel, model_serializer, field_validator, error_wrappers +from pydantic import ( + confloat, + BaseModel, + model_serializer, + field_validator, + error_wrappers, +) from pydantic_core import ValidationError, InitErrorDetails ConfidenceValue = confloat(ge=0, le=1) @@ -21,15 +27,15 @@ def serialize_model(self, handler): res = handler(self) return {k: v for k, v in 
res.items() if v is not None} - - @field_validator('value') + @field_validator("value") def validate_value(cls, value): if isinstance(value, Dict): - if not (MIN_CONFIDENCE_SCORES <= len(value) <= - MAX_CONFIDENCE_SCORES): + if not ( + MIN_CONFIDENCE_SCORES <= len(value) <= MAX_CONFIDENCE_SCORES + ): raise ValueError( - f"Number of confidence scores must be greater than\n \ + f"Number of confidence scores must be greater than\n \ or equal to {MIN_CONFIDENCE_SCORES} and less than\n \ or equal to {MAX_CONFIDENCE_SCORES}. Found {len(value)}" - ) + ) return value diff --git a/libs/labelbox/src/labelbox/data/annotation_types/metrics/confusion_matrix.py b/libs/labelbox/src/labelbox/data/annotation_types/metrics/confusion_matrix.py index 4a346b8f4..30e2d2ed4 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/metrics/confusion_matrix.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/metrics/confusion_matrix.py @@ -9,8 +9,9 @@ Count = conint(ge=0, le=1e10) ConfusionMatrixMetricValue = Tuple[Count, Count, Count, Count] -ConfusionMatrixMetricConfidenceValue = Dict[ConfidenceValue, - ConfusionMatrixMetricValue] +ConfusionMatrixMetricConfidenceValue = Dict[ + ConfidenceValue, ConfusionMatrixMetricValue +] class ConfusionMatrixAggregation(Enum): @@ -18,7 +19,7 @@ class ConfusionMatrixAggregation(Enum): class ConfusionMatrixMetric(BaseMetric): - """ Class representing confusion matrix metrics. + """Class representing confusion matrix metrics. In the editor, this provides precision, recall, and f-scores. This should be used over multiple scalar metrics so that aggregations are accurate. @@ -28,7 +29,11 @@ class ConfusionMatrixMetric(BaseMetric): aggregation cannot be adjusted for confusion matrix metrics. """ + metric_name: str - value: Union[ConfusionMatrixMetricValue, - ConfusionMatrixMetricConfidenceValue] - aggregation: Optional[ConfusionMatrixAggregation] = ConfusionMatrixAggregation.CONFUSION_MATRIX + value: Union[ + ConfusionMatrixMetricValue, ConfusionMatrixMetricConfidenceValue + ] + aggregation: Optional[ConfusionMatrixAggregation] = ( + ConfusionMatrixAggregation.CONFUSION_MATRIX + ) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py b/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py index 560d6dcef..13d0e9748 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py @@ -18,30 +18,41 @@ class ScalarMetricAggregation(Enum): SUM = "SUM" -RESERVED_METRIC_NAMES = ('true_positive_count', 'false_positive_count', - 'true_negative_count', 'false_negative_count', - 'precision', 'recall', 'f1', 'iou') +RESERVED_METRIC_NAMES = ( + "true_positive_count", + "false_positive_count", + "true_negative_count", + "false_negative_count", + "precision", + "recall", + "f1", + "iou", +) class ScalarMetric(BaseMetric): - """ Class representing scalar metrics + """Class representing scalar metrics For backwards compatibility, metric_name is optional. The metric_name will be set to a default name in the editor if it is not set. This is not recommended and support for empty metric_name fields will be removed. aggregation will be ignored without providing a metric name. 
""" + metric_name: Optional[str] = None value: Union[ScalarMetricValue, ScalarMetricConfidenceValue] - aggregation: Optional[ - ScalarMetricAggregation] = ScalarMetricAggregation.ARITHMETIC_MEAN + aggregation: Optional[ScalarMetricAggregation] = ( + ScalarMetricAggregation.ARITHMETIC_MEAN + ) - @field_validator('metric_name') + @field_validator("metric_name") def validate_metric_name(cls, name: Union[str, None]): if name is None: return None clean_name = name.lower().strip() if clean_name in RESERVED_METRIC_NAMES: - raise ValueError(f"`{clean_name}` is a reserved metric name. " - "Please provide another value for `metric_name`.") + raise ValueError( + f"`{clean_name}` is a reserved metric name. " + "Please provide another value for `metric_name`." + ) return name diff --git a/libs/labelbox/src/labelbox/data/annotation_types/mmc.py b/libs/labelbox/src/labelbox/data/annotation_types/mmc.py index d3ab763cb..e2ed74d41 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/mmc.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/mmc.py @@ -10,7 +10,7 @@ class MessageInfo(_CamelCaseMixin): message_id: str model_config_name: str - + model_config = ConfigDict(protected_namespaces=()) @@ -21,7 +21,7 @@ class OrderedMessageInfo(MessageInfo): class _BaseMessageEvaluationTask(_CamelCaseMixin, ABC): format: ClassVar[str] parent_message_id: str - + model_config = ConfigDict(protected_namespaces=()) @@ -48,5 +48,8 @@ def _validate_ranked_messages(cls, v: List[OrderedMessageInfo]): class MessageEvaluationTaskAnnotation(BaseAnnotation): - value: Union[MessageSingleSelectionTask, MessageMultiSelectionTask, - MessageRankingTask] + value: Union[ + MessageSingleSelectionTask, + MessageMultiSelectionTask, + MessageRankingTask, + ] diff --git a/libs/labelbox/src/labelbox/data/annotation_types/ner/conversation_entity.py b/libs/labelbox/src/labelbox/data/annotation_types/ner/conversation_entity.py index 53b9059b9..e8bd49b56 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/ner/conversation_entity.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/ner/conversation_entity.py @@ -3,5 +3,6 @@ class ConversationEntity(TextEntity, _CamelCaseMixin): - """ Represents a text entity """ - message_id: str \ No newline at end of file + """Represents a text entity""" + + message_id: str diff --git a/libs/labelbox/src/labelbox/data/annotation_types/ner/document_entity.py b/libs/labelbox/src/labelbox/data/annotation_types/ner/document_entity.py index c2acecd7c..6a5abec23 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/ner/document_entity.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/ner/document_entity.py @@ -17,5 +17,6 @@ def validate_page(cls, v): class DocumentEntity(_CamelCaseMixin, BaseModel): - """ Represents a text entity """ + """Represents a text entity""" + text_selections: List[DocumentTextSelection] diff --git a/libs/labelbox/src/labelbox/data/annotation_types/ner/text_entity.py b/libs/labelbox/src/labelbox/data/annotation_types/ner/text_entity.py index 60764f759..ece341434 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/ner/text_entity.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/ner/text_entity.py @@ -4,16 +4,17 @@ class TextEntity(BaseModel): - """ Represents a text entity """ + """Represents a text entity""" + start: int end: int extra: Dict[str, Any] = {} @model_validator(mode="after") def validate_start_end(self, values): - if hasattr(self, 'start') and hasattr(self, 'end'): - if (isinstance(self.start, int) and - self.start > 
self.end): + if hasattr(self, "start") and hasattr(self, "end"): + if isinstance(self.start, int) and self.start > self.end: raise ValueError( - "Location end must be greater or equal to start") + "Location end must be greater or equal to start" + ) return self diff --git a/libs/labelbox/src/labelbox/data/annotation_types/relationship.py b/libs/labelbox/src/labelbox/data/annotation_types/relationship.py index 27a833830..b65f21d16 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/relationship.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/relationship.py @@ -1,10 +1,12 @@ from pydantic import BaseModel from enum import Enum -from labelbox.data.annotation_types.annotation import BaseAnnotation, ObjectAnnotation +from labelbox.data.annotation_types.annotation import ( + BaseAnnotation, + ObjectAnnotation, +) class Relationship(BaseModel): - class Type(Enum): UNIDIRECTIONAL = "unidirectional" BIDIRECTIONAL = "bidirectional" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/types.py b/libs/labelbox/src/labelbox/data/annotation_types/types.py index b26789aae..0a9793f8f 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/types.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/types.py @@ -9,12 +9,11 @@ Cuid = Annotated[str, StringConstraints(min_length=25, max_length=25)] -DType = TypeVar('DType') -DShape = TypeVar('DShape') +DType = TypeVar("DType") +DShape = TypeVar("DShape") class _TypedArray(np.ndarray, Generic[DType, DShape]): - @classmethod def __get_validators__(cls): yield cls.validate @@ -26,15 +25,21 @@ def validate(cls, val, field: Field): return val -if version.parse(np.__version__) >= version.parse('1.25.0'): +if version.parse(np.__version__) >= version.parse("1.25.0"): from typing import GenericAlias + TypedArray = GenericAlias(_TypedArray, (Any, DType)) -elif version.parse(np.__version__) >= version.parse('1.23.0'): +elif version.parse(np.__version__) >= version.parse("1.23.0"): from numpy._typing import _GenericAlias + TypedArray = _GenericAlias(_TypedArray, (Any, DType)) -elif version.parse('1.22.0') <= version.parse( - np.__version__) < version.parse('1.23.0'): +elif ( + version.parse("1.22.0") + <= version.parse(np.__version__) + < version.parse("1.23.0") +): from numpy.typing import _GenericAlias + TypedArray = _GenericAlias(_TypedArray, (Any, DType)) else: TypedArray = _TypedArray[Any, DType] diff --git a/libs/labelbox/src/labelbox/data/annotation_types/video.py b/libs/labelbox/src/labelbox/data/annotation_types/video.py index 79a14ec2d..cfebd7a1f 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/video.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/video.py @@ -1,13 +1,30 @@ from enum import Enum from typing import List, Optional, Tuple -from labelbox.data.annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation - -from labelbox.data.annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation +from labelbox.data.annotation_types.annotation import ( + ClassificationAnnotation, + ObjectAnnotation, +) + +from labelbox.data.annotation_types.annotation import ( + ClassificationAnnotation, + ObjectAnnotation, +) from labelbox.data.annotation_types.feature import FeatureSchema -from labelbox.data.mixins import ConfidenceNotSupportedMixin, CustomMetricsNotSupportedMixin +from labelbox.data.mixins import ( + ConfidenceNotSupportedMixin, + CustomMetricsNotSupportedMixin, +) from labelbox.utils import _CamelCaseMixin, is_valid_uri -from pydantic import model_validator, 
BaseModel, field_validator, model_serializer, Field, ConfigDict, AliasChoices +from pydantic import ( + model_validator, + BaseModel, + field_validator, + model_serializer, + Field, + ConfigDict, + AliasChoices, +) class VideoClassificationAnnotation(ClassificationAnnotation): @@ -20,12 +37,16 @@ class VideoClassificationAnnotation(ClassificationAnnotation): segment_id (Optional[Int]): Index of video segment this annotation belongs to extra (Dict[str, Any]) """ + frame: int segment_index: Optional[int] = None -class VideoObjectAnnotation(ObjectAnnotation, ConfidenceNotSupportedMixin, - CustomMetricsNotSupportedMixin): +class VideoObjectAnnotation( + ObjectAnnotation, + ConfidenceNotSupportedMixin, + CustomMetricsNotSupportedMixin, +): """Video object annotation >>> VideoObjectAnnotation( >>> keyframe=True, @@ -46,14 +67,15 @@ class VideoObjectAnnotation(ObjectAnnotation, ConfidenceNotSupportedMixin, classifications (List[ClassificationAnnotation]) = [] extra (Dict[str, Any]) """ + frame: int keyframe: bool segment_index: Optional[int] = None class GroupKey(Enum): - """Group key for DICOM annotations - """ + """Group key for DICOM annotations""" + AXIAL = "axial" SAGITTAL = "sagittal" CORONAL = "coronal" @@ -84,14 +106,19 @@ class DICOMObjectAnnotation(VideoObjectAnnotation): classifications (List[ClassificationAnnotation]) = [] extra (Dict[str, Any]) """ + group_key: GroupKey class MaskFrame(_CamelCaseMixin, BaseModel): index: int - instance_uri: Optional[str] = Field(default=None, validation_alias=AliasChoices("instanceURI", "instanceUri"), serialization_alias="instanceURI") + instance_uri: Optional[str] = Field( + default=None, + validation_alias=AliasChoices("instanceURI", "instanceUri"), + serialization_alias="instanceURI", + ) im_bytes: Optional[bytes] = None - + model_config = ConfigDict(populate_by_name=True) @model_validator(mode="after") @@ -110,43 +137,49 @@ def validate_uri(cls, v): class MaskInstance(_CamelCaseMixin, FeatureSchema): - color_rgb: Tuple[int, int, int] = Field(validation_alias=AliasChoices("colorRGB", "colorRgb"), serialization_alias="colorRGB") + color_rgb: Tuple[int, int, int] = Field( + validation_alias=AliasChoices("colorRGB", "colorRgb"), + serialization_alias="colorRGB", + ) name: str model_config = ConfigDict(populate_by_name=True) + class VideoMaskAnnotation(BaseModel): """Video mask annotation - >>> VideoMaskAnnotation( - >>> frames=[ - >>> MaskFrame(index=1, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), - >>> MaskFrame(index=5, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), - >>> ], - >>> instances=[ - >>> MaskInstance(color_rgb=(0, 0, 255), name="mask1"), - >>> MaskInstance(color_rgb=(0, 255, 0), name="mask2"), - >>> MaskInstance(color_rgb=(255, 0, 0), name="mask3") - >>> ] - >>> ) + >>> VideoMaskAnnotation( + >>> frames=[ + >>> MaskFrame(index=1, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), + >>> MaskFrame(index=5, 
instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), + >>> ], + >>> instances=[ + >>> MaskInstance(color_rgb=(0, 0, 255), name="mask1"), + >>> MaskInstance(color_rgb=(0, 255, 0), name="mask2"), + >>> MaskInstance(color_rgb=(255, 0, 0), name="mask3") + >>> ] + >>> ) """ + frames: List[MaskFrame] instances: List[MaskInstance] class DICOMMaskAnnotation(VideoMaskAnnotation): """DICOM mask annotation - >>> DICOMMaskAnnotation( - >>> name="dicom_mask", - >>> group_key=GroupKey.AXIAL, - >>> frames=[ - >>> MaskFrame(index=1, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), - >>> MaskFrame(index=5, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), - >>> ], - >>> instances=[ - >>> MaskInstance(color_rgb=(0, 0, 255), name="mask1"), - >>> MaskInstance(color_rgb=(0, 255, 0), name="mask2"), - >>> MaskInstance(color_rgb=(255, 0, 0), name="mask3") - >>> ] - >>> ) + >>> DICOMMaskAnnotation( + >>> name="dicom_mask", + >>> group_key=GroupKey.AXIAL, + >>> frames=[ + >>> MaskFrame(index=1, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), + >>> MaskFrame(index=5, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), + >>> ], + >>> instances=[ + >>> MaskInstance(color_rgb=(0, 0, 255), name="mask1"), + >>> MaskInstance(color_rgb=(0, 255, 0), name="mask2"), + >>> MaskInstance(color_rgb=(255, 0, 0), name="mask3") + >>> ] + >>> ) """ + group_key: GroupKey diff --git a/libs/labelbox/src/labelbox/data/generator.py b/libs/labelbox/src/labelbox/data/generator.py index 891dc1315..8270a6715 100644 --- a/libs/labelbox/src/labelbox/data/generator.py +++ b/libs/labelbox/src/labelbox/data/generator.py @@ -13,9 +13,7 @@ class ThreadSafeGen: """ def __init__(self, iterable: Iterable[Any]): - """ - - """ + """ """ self.iterable = iterable self.lock = threading.Lock() @@ -70,7 +68,8 @@ def fill_queue(self): self.queue.put(value) except Exception as e: self.queue.put( - ValueError(f"Unexpected exception while filling queue: {e}")) + ValueError(f"Unexpected exception while filling queue: {e}") + ) finally: self.queue.put(None) diff --git a/libs/labelbox/src/labelbox/data/metrics/__init__.py b/libs/labelbox/src/labelbox/data/metrics/__init__.py index f99fc85a8..7085b772e 100644 --- a/libs/labelbox/src/labelbox/data/metrics/__init__.py +++ b/libs/labelbox/src/labelbox/data/metrics/__init__.py @@ -1,2 +1,5 @@ -from .confusion_matrix import confusion_matrix_metric, feature_confusion_matrix_metric +from .confusion_matrix import ( + confusion_matrix_metric, + feature_confusion_matrix_metric, +) from .iou import miou_metric, feature_miou_metric diff --git a/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py b/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py index 1b1fc801b..938e17f65 100644 --- 
a/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py +++ b/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py @@ -2,20 +2,37 @@ import numpy as np -from ..iou.calculation import _get_mask_pairs, _get_vector_pairs, _get_ner_pairs, miou -from ...annotation_types import (ObjectAnnotation, ClassificationAnnotation, - Mask, Geometry, Checklist, Radio, TextEntity, - ScalarMetricValue, ConfusionMatrixMetricValue) -from ..group import (get_feature_pairs, get_identifying_key, has_no_annotations, - has_no_matching_annotations) - - -def confusion_matrix(ground_truths: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - include_subclasses: bool, - iou: float) -> ConfusionMatrixMetricValue: +from ..iou.calculation import ( + _get_mask_pairs, + _get_vector_pairs, + _get_ner_pairs, + miou, +) +from ...annotation_types import ( + ObjectAnnotation, + ClassificationAnnotation, + Mask, + Geometry, + Checklist, + Radio, + TextEntity, + ScalarMetricValue, + ConfusionMatrixMetricValue, +) +from ..group import ( + get_feature_pairs, + get_identifying_key, + has_no_annotations, + has_no_matching_annotations, +) + + +def confusion_matrix( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + include_subclasses: bool, + iou: float, +) -> ConfusionMatrixMetricValue: """ Computes the confusion matrix for an arbitrary set of ground truth and predicted annotations. It first computes the confusion matrix for each metric and then sums across all classes @@ -33,8 +50,9 @@ def confusion_matrix(ground_truths: List[Union[ObjectAnnotation, annotation_pairs = get_feature_pairs(ground_truths, predictions) conf_matrix = [ - feature_confusion_matrix(annotation_pair[0], annotation_pair[1], - include_subclasses, iou) + feature_confusion_matrix( + annotation_pair[0], annotation_pair[1], include_subclasses, iou + ) for annotation_pair in annotation_pairs.values() ] matrices = [matrix for matrix in conf_matrix if matrix is not None] @@ -42,10 +60,11 @@ def confusion_matrix(ground_truths: List[Union[ObjectAnnotation, def feature_confusion_matrix( - ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], - include_subclasses: bool, - iou: float) -> Optional[ConfusionMatrixMetricValue]: + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + include_subclasses: bool, + iou: float, +) -> Optional[ConfusionMatrixMetricValue]: """ Computes confusion matrix for all features of the same class. 
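# Illustrative usage sketch for the reformatted confusion_matrix() above.
# Assumptions: the names ("cat") and box coordinates are placeholders, and the
# import paths simply mirror the files touched in this patch. confusion_matrix()
# sums the per-feature-class matrices into one [tps, fps, tns, fns] vector.
from labelbox.data.annotation_types import ObjectAnnotation
from labelbox.data.annotation_types.geometry.point import Point
from labelbox.data.annotation_types.geometry.rectangle import Rectangle
from labelbox.data.metrics.confusion_matrix.calculation import confusion_matrix

ground_truths = [
    ObjectAnnotation(
        name="cat",
        value=Rectangle(start=Point(x=0, y=0), end=Point(x=10, y=10)),
    )
]
predictions = [
    ObjectAnnotation(
        name="cat",
        value=Rectangle(start=Point(x=1, y=1), end=Point(x=10, y=10)),
    )
]

# The boxes overlap with IoU ~0.81, so at a 0.5 threshold this should count
# as a single true positive: [1, 0, 0, 0].
confusion_matrix(ground_truths, predictions, include_subclasses=False, iou=0.5)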
@@ -63,24 +82,28 @@ def feature_confusion_matrix( elif has_no_annotations(ground_truths, predictions): return None elif isinstance(predictions[0].value, Mask): - return mask_confusion_matrix(ground_truths, predictions, - include_subclasses, iou) + return mask_confusion_matrix( + ground_truths, predictions, include_subclasses, iou + ) elif isinstance(predictions[0].value, Geometry): - return vector_confusion_matrix(ground_truths, predictions, - include_subclasses, iou) + return vector_confusion_matrix( + ground_truths, predictions, include_subclasses, iou + ) elif isinstance(predictions[0].value, TextEntity): - return ner_confusion_matrix(ground_truths, predictions, - include_subclasses, iou) + return ner_confusion_matrix( + ground_truths, predictions, include_subclasses, iou + ) elif isinstance(predictions[0], ClassificationAnnotation): return classification_confusion_matrix(ground_truths, predictions) else: raise ValueError( - f"Unexpected annotation found. Found {type(predictions[0].value)}") + f"Unexpected annotation found. Found {type(predictions[0].value)}" + ) def classification_confusion_matrix( - ground_truths: List[ClassificationAnnotation], - predictions: List[ClassificationAnnotation] + ground_truths: List[ClassificationAnnotation], + predictions: List[ClassificationAnnotation], ) -> ConfusionMatrixMetricValue: """ Computes iou score for all features with the same feature schema id. @@ -97,9 +120,11 @@ def classification_confusion_matrix( if has_no_matching_annotations(ground_truths, predictions): return [0, len(predictions), 0, len(ground_truths)] - elif has_no_annotations( - ground_truths, - predictions) or len(predictions) > 1 or len(ground_truths) > 1: + elif ( + has_no_annotations(ground_truths, predictions) + or len(predictions) > 1 + or len(ground_truths) > 1 + ): # Note that we could return [0,0,0,0] but that will bloat the imports for no reason return None @@ -108,7 +133,8 @@ def classification_confusion_matrix( if type(prediction) != type(ground_truth): raise TypeError( "Classification features must be the same type to compute agreement. " - f"Found `{type(prediction)}` and `{type(ground_truth)}`") + f"Found `{type(prediction)}` and `{type(ground_truth)}`" + ) if isinstance(prediction.value, Radio): return radio_confusion_matrix(ground_truth.value, prediction.value) @@ -120,11 +146,13 @@ def classification_confusion_matrix( ) -def vector_confusion_matrix(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation], - include_subclasses: bool, - iou: float, - buffer=70.) -> Optional[ConfusionMatrixMetricValue]: +def vector_confusion_matrix( + ground_truths: List[ObjectAnnotation], + predictions: List[ObjectAnnotation], + include_subclasses: bool, + iou: float, + buffer=70.0, +) -> Optional[ConfusionMatrixMetricValue]: """ Computes confusion matrix for any vector class (point, polygon, line, rectangle). Ground truths and predictions should all belong to the same class. 
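# Sketch of the greedy matching performed by object_pair_confusion_matrix in
# the hunk below, reduced to plain (gt_id, pred_id, iou) tuples so the counting
# logic is easy to follow. Pairs are visited in descending IoU order; a pair is
# a true positive only if it clears the threshold and neither side has been
# matched already. This illustrates the algorithm, not a drop-in replacement.
def greedy_counts(pairs, iou_threshold):
    # pairs: (ground_truth_id, prediction_id, iou_score) triplets
    pairs = sorted(pairs, key=lambda p: p[2], reverse=True)
    matched_gt, matched_pred = set(), set()
    tps = 0
    for gt_id, pred_id, score in pairs:
        if (
            score > iou_threshold
            and gt_id not in matched_gt
            and pred_id not in matched_pred
        ):
            matched_gt.add(gt_id)
            matched_pred.add(pred_id)
            tps += 1
    fps = len({pred_id for _, pred_id, _ in pairs}) - tps  # unmatched predictions
    fns = len({gt_id for gt_id, _, _ in pairs}) - tps  # unmatched ground truths
    return [tps, fps, 0, fns]  # [tps, fps, tns, fns]

greedy_counts([("gt1", "p1", 0.8), ("gt2", "p1", 0.6)], 0.5)  # -> [1, 0, 0, 1]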
@@ -149,11 +177,11 @@ def vector_confusion_matrix(ground_truths: List[ObjectAnnotation], return object_pair_confusion_matrix(pairs, include_subclasses, iou) -def object_pair_confusion_matrix(pairs: List[Tuple[ObjectAnnotation, - ObjectAnnotation, - ScalarMetricValue]], - include_subclasses: bool, - iou: float) -> ConfusionMatrixMetricValue: +def object_pair_confusion_matrix( + pairs: List[Tuple[ObjectAnnotation, ObjectAnnotation, ScalarMetricValue]], + include_subclasses: bool, + iou: float, +) -> ConfusionMatrixMetricValue: """ Computes the confusion matrix for a list of object annotation pairs. Performs greedy matching of pairs. @@ -177,14 +205,22 @@ def object_pair_confusion_matrix(pairs: List[Tuple[ObjectAnnotation, prediction_ids.add(prediction_id) ground_truth_ids.add(ground_truth_id) - if agreement > iou and \ - prediction_id not in matched_predictions and \ - ground_truth_id not in matched_ground_truths: - if include_subclasses and (ground_truth.classifications or - prediction.classifications): - if miou(ground_truth.classifications, + if ( + agreement > iou + and prediction_id not in matched_predictions + and ground_truth_id not in matched_ground_truths + ): + if include_subclasses and ( + ground_truth.classifications or prediction.classifications + ): + if ( + miou( + ground_truth.classifications, prediction.classifications, - include_subclasses=False) < 1.: + include_subclasses=False, + ) + < 1.0 + ): # Incorrect if the subclasses don't 100% agree then there is no match continue matched_predictions.add(prediction_id) @@ -198,8 +234,9 @@ def object_pair_confusion_matrix(pairs: List[Tuple[ObjectAnnotation, return [tps, fps, tns, fns] -def radio_confusion_matrix(ground_truth: Radio, - prediction: Radio) -> ConfusionMatrixMetricValue: +def radio_confusion_matrix( + ground_truth: Radio, prediction: Radio +) -> ConfusionMatrixMetricValue: """ Calculates confusion between ground truth and predicted radio values @@ -220,8 +257,8 @@ def radio_confusion_matrix(ground_truth: Radio, def checklist_confusion_matrix( - ground_truth: Checklist, - prediction: Checklist) -> ConfusionMatrixMetricValue: + ground_truth: Checklist, prediction: Checklist +) -> ConfusionMatrixMetricValue: """ Calculates agreement between ground truth and predicted checklist items: @@ -246,10 +283,12 @@ def checklist_confusion_matrix( return [tps, fps, 0, fns] -def mask_confusion_matrix(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation], - include_subclasses: bool, - iou: float) -> Optional[ScalarMetricValue]: +def mask_confusion_matrix( + ground_truths: List[ObjectAnnotation], + predictions: List[ObjectAnnotation], + include_subclasses: bool, + iou: float, +) -> Optional[ScalarMetricValue]: """ Computes confusion matrix metric for two masks @@ -269,15 +308,17 @@ def mask_confusion_matrix(ground_truths: List[ObjectAnnotation], return None pairs = _get_mask_pairs(ground_truths, predictions) - return object_pair_confusion_matrix(pairs, - include_subclasses=include_subclasses, - iou=iou) + return object_pair_confusion_matrix( + pairs, include_subclasses=include_subclasses, iou=iou + ) -def ner_confusion_matrix(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation], - include_subclasses: bool, - iou: float) -> Optional[ConfusionMatrixMetricValue]: +def ner_confusion_matrix( + ground_truths: List[ObjectAnnotation], + predictions: List[ObjectAnnotation], + include_subclasses: bool, + iou: float, +) -> Optional[ConfusionMatrixMetricValue]: """Computes confusion matrix metric 
between two lists of TextEntity objects Args: diff --git a/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/confusion_matrix.py b/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/confusion_matrix.py index 19caab426..6d817b105 100644 --- a/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/confusion_matrix.py +++ b/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/confusion_matrix.py @@ -3,8 +3,11 @@ from labelbox.data.annotation_types import feature from labelbox.data.annotation_types.metrics import ConfusionMatrixMetric from typing import List, Optional, Union -from ...annotation_types import (Label, ObjectAnnotation, - ClassificationAnnotation) +from ...annotation_types import ( + Label, + ObjectAnnotation, + ClassificationAnnotation, +) from ..group import get_feature_pairs from .calculation import confusion_matrix @@ -12,12 +15,12 @@ import numpy as np -def confusion_matrix_metric(ground_truths: List[Union[ - ObjectAnnotation, ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - include_subclasses=False, - iou=0.5) -> List[ConfusionMatrixMetric]: +def confusion_matrix_metric( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + include_subclasses=False, + iou=0.5, +) -> List[ConfusionMatrixMetric]: """ Computes confusion matrix metrics between two sets of annotations. These annotations should relate to the same data (image/video). @@ -31,11 +34,12 @@ def confusion_matrix_metric(ground_truths: List[Union[ Returns: Returns a list of ConfusionMatrixMetrics. Will be empty if there were no predictions and labels. Otherwise a single metric will be returned. """ - if not (0. < iou < 1.): + if not (0.0 < iou < 1.0): raise ValueError("iou must be between 0 and 1") - value = confusion_matrix(ground_truths, predictions, include_subclasses, - iou) + value = confusion_matrix( + ground_truths, predictions, include_subclasses, iou + ) # If both gt and preds are empty there is no metric if value is None: return [] @@ -68,39 +72,45 @@ def feature_confusion_matrix_metric( annotation_pairs = get_feature_pairs(ground_truths, predictions) metrics = [] for key in annotation_pairs: - value = feature_confusion_matrix(annotation_pairs[key][0], - annotation_pairs[key][1], - include_subclasses, iou) + value = feature_confusion_matrix( + annotation_pairs[key][0], + annotation_pairs[key][1], + include_subclasses, + iou, + ) if value is None: continue - metric_name = _get_metric_name(annotation_pairs[key][0], - annotation_pairs[key][1], iou) + metric_name = _get_metric_name( + annotation_pairs[key][0], annotation_pairs[key][1], iou + ) metrics.append( - ConfusionMatrixMetric(metric_name=metric_name, - feature_name=key, - value=value)) + ConfusionMatrixMetric( + metric_name=metric_name, feature_name=key, value=value + ) + ) return metrics -def _get_metric_name(ground_truths: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - iou: float): - +def _get_metric_name( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + iou: float, +): if _is_classification(ground_truths, predictions): return "classification" return f"{int(iou*100)}pct_iou" -def _is_classification(ground_truths: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - predictions: 
List[Union[ObjectAnnotation, - ClassificationAnnotation]]): +def _is_classification( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], +): # Check if either the prediction or label contains a classification annotation - return (len(predictions) and - isinstance(predictions[0], ClassificationAnnotation) or - len(ground_truths) and - isinstance(ground_truths[0], ClassificationAnnotation)) + return ( + len(predictions) + and isinstance(predictions[0], ClassificationAnnotation) + or len(ground_truths) + and isinstance(ground_truths[0], ClassificationAnnotation) + ) diff --git a/libs/labelbox/src/labelbox/data/metrics/group.py b/libs/labelbox/src/labelbox/data/metrics/group.py index 5579ac9ce..88f4eae8b 100644 --- a/libs/labelbox/src/labelbox/data/metrics/group.py +++ b/libs/labelbox/src/labelbox/data/metrics/group.py @@ -1,11 +1,18 @@ """ Tools for grouping features and labels so that we can compute metrics on the individual groups """ + from collections import defaultdict from typing import Dict, List, Tuple, Union from labelbox.data.annotation_types.annotation import ClassificationAnnotation -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnswer, Radio, Text +from labelbox.data.annotation_types.classification.classification import ( + Checklist, + ClassificationAnswer, + Radio, + Text, +) + try: from typing import Literal except ImportError: @@ -17,7 +24,7 @@ def get_identifying_key( features_a: List[FeatureSchema], features_b: List[FeatureSchema] -) -> Union[Literal['name'], Literal['feature_schema_id']]: +) -> Union[Literal["name"], Literal["feature_schema_id"]]: """ Checks to make sure that features in both sets contain the same type of identifying keys. This can either be the feature name or feature schema id. @@ -30,22 +37,24 @@ def get_identifying_key( """ all_schema_ids_defined_pred, all_names_defined_pred = all_have_key( - features_a) - if (not all_schema_ids_defined_pred and not all_names_defined_pred): + features_a + ) + if not all_schema_ids_defined_pred and not all_names_defined_pred: raise ValueError("All data must have feature_schema_ids or names set") all_schema_ids_defined_gt, all_names_defined_gt = all_have_key(features_b) # Prefer name becuse the user will be able to know what it means # Schema id incase that doesn't exist. - if (all_names_defined_pred and all_names_defined_gt): - return 'name' + if all_names_defined_pred and all_names_defined_gt: + return "name" elif all_schema_ids_defined_pred and all_schema_ids_defined_gt: - return 'feature_schema_id' + return "feature_schema_id" else: raise ValueError( "Ground truth and prediction annotations must have set all name or feature ids. " - "Otherwise there is no key to match on. Please update.") + "Otherwise there is no key to match on. Please update." + ) def all_have_key(features: List[FeatureSchema]) -> Tuple[bool, bool]: @@ -79,10 +88,9 @@ def all_have_key(features: List[FeatureSchema]) -> Tuple[bool, bool]: return all_schemas, all_names -def get_label_pairs(labels_a: list, - labels_b: list, - match_on="uid", - filter_mismatch=False) -> Dict[str, Tuple[Label, Label]]: +def get_label_pairs( + labels_a: list, labels_b: list, match_on="uid", filter_mismatch=False +) -> Dict[str, Tuple[Label, Label]]: """ This is a function to pairing a list of prediction labels and a list of ground truth labels easier. There are a few potentiall problems with this function. 
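# Illustrative sketch of calling get_label_pairs (continued in the hunk below):
# it pairs prediction and ground-truth Labels that refer to the same data row
# and returns a Dict[str, Tuple[Label, Label]] keyed by that identifier.
# Assumptions: "row-1" is a placeholder data row id, and the dict form of
# Label.data follows the single-key convention described in the Label docstring.
from labelbox.data.annotation_types import Label
from labelbox.data.metrics.group import get_label_pairs

ground_truth_labels = [Label(data={"uid": "row-1"}, annotations=[])]
prediction_labels = [Label(data={"uid": "row-1"}, annotations=[])]

# match_on must be "uid" or "external_id", as validated below.
pairs = get_label_pairs(ground_truth_labels, prediction_labels, match_on="uid")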
@@ -101,7 +109,7 @@ def get_label_pairs(labels_a: list, """ - if match_on not in ['uid', 'external_id']: + if match_on not in ["uid", "external_id"]: raise ValueError("Can only match on `uid` or `exteranl_id`.") label_lookup_a = { @@ -147,9 +155,10 @@ def get_feature_pairs( """ identifying_key = get_identifying_key(features_a, features_b) - lookup_a, lookup_b = _create_feature_lookup( - features_a, - identifying_key), _create_feature_lookup(features_b, identifying_key) + lookup_a, lookup_b = ( + _create_feature_lookup(features_a, identifying_key), + _create_feature_lookup(features_b, identifying_key), + ) keys = set(lookup_a.keys()).union(set(lookup_b.keys())) result = defaultdict(list) @@ -158,8 +167,9 @@ def get_feature_pairs( return result -def _create_feature_lookup(features: List[FeatureSchema], - key: str) -> Dict[str, List[FeatureSchema]]: +def _create_feature_lookup( + features: List[FeatureSchema], key: str +) -> Dict[str, List[FeatureSchema]]: """ Groups annotation by name (if available otherwise feature schema id). @@ -172,29 +182,33 @@ def _create_feature_lookup(features: List[FeatureSchema], grouped_features = defaultdict(list) for feature in features: if isinstance(feature, ClassificationAnnotation): - #checklists + # checklists if isinstance(feature.value, Checklist): for answer in feature.value.answer: new_answer = Radio(answer=answer) new_annotation = ClassificationAnnotation( value=new_answer, name=answer.name, - feature_schema_id=answer.feature_schema_id) + feature_schema_id=answer.feature_schema_id, + ) - grouped_features[getattr(answer, - key)].append(new_annotation) + grouped_features[getattr(answer, key)].append( + new_annotation + ) elif isinstance(feature.value, Text): grouped_features[getattr(feature, key)].append(feature) else: - grouped_features[getattr(feature.value.answer, - key)].append(feature) + grouped_features[getattr(feature.value.answer, key)].append( + feature + ) else: grouped_features[getattr(feature, key)].append(feature) return grouped_features -def has_no_matching_annotations(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation]): +def has_no_matching_annotations( + ground_truths: List[ObjectAnnotation], predictions: List[ObjectAnnotation] +): if len(ground_truths) and not len(predictions): # No existing predictions but existing ground truths means no matches. 
return True @@ -204,6 +218,7 @@ def has_no_matching_annotations(ground_truths: List[ObjectAnnotation], return False -def has_no_annotations(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation]): +def has_no_annotations( + ground_truths: List[ObjectAnnotation], predictions: List[ObjectAnnotation] +): return not len(ground_truths) and not len(predictions) diff --git a/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py b/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py index e25035c1b..2a376d3fe 100644 --- a/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py +++ b/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py @@ -4,15 +4,32 @@ import numpy as np from shapely.geometry import Polygon -from ..group import get_feature_pairs, get_identifying_key, has_no_annotations, has_no_matching_annotations -from ...annotation_types import (ObjectAnnotation, ClassificationAnnotation, - Mask, Geometry, Point, Line, Checklist, Text, - TextEntity, Radio, ScalarMetricValue) - - -def miou(ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], - include_subclasses: bool) -> Optional[ScalarMetricValue]: +from ..group import ( + get_feature_pairs, + get_identifying_key, + has_no_annotations, + has_no_matching_annotations, +) +from ...annotation_types import ( + ObjectAnnotation, + ClassificationAnnotation, + Mask, + Geometry, + Point, + Line, + Checklist, + Text, + TextEntity, + Radio, + ScalarMetricValue, +) + + +def miou( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + include_subclasses: bool, +) -> Optional[ScalarMetricValue]: """ Computes miou for an arbitrary set of ground truth and predicted annotations. It first computes the iou for each metric and then takes the average (weighting each class equally) @@ -35,11 +52,11 @@ def miou(ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], return None if not len(ious) else np.mean(ious) -def feature_miou(ground_truths: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - include_subclasses: bool) -> Optional[ScalarMetricValue]: +def feature_miou( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + include_subclasses: bool, +) -> Optional[ScalarMetricValue]: """ Computes iou score for all features of the same class. @@ -52,7 +69,7 @@ def feature_miou(ground_truths: List[Union[ObjectAnnotation, float representing the iou score for the feature type if score can be computed otherwise None. """ if has_no_matching_annotations(ground_truths, predictions): - return 0. + return 0.0 elif has_no_annotations(ground_truths, predictions): return None elif isinstance(predictions[0].value, Mask): @@ -65,13 +82,16 @@ def feature_miou(ground_truths: List[Union[ObjectAnnotation, return ner_miou(ground_truths, predictions, include_subclasses) else: raise ValueError( - f"Unexpected annotation found. Found {type(predictions[0].value)}") + f"Unexpected annotation found. Found {type(predictions[0].value)}" + ) -def vector_miou(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation], - include_subclasses: bool, - buffer=70.) 
-> Optional[ScalarMetricValue]: +def vector_miou( + ground_truths: List[ObjectAnnotation], + predictions: List[ObjectAnnotation], + include_subclasses: bool, + buffer=70.0, +) -> Optional[ScalarMetricValue]: """ Computes iou score for all features with the same feature schema id. Calculation includes subclassifications. @@ -84,44 +104,57 @@ def vector_miou(ground_truths: List[ObjectAnnotation], If there are no matches then this returns none """ if has_no_matching_annotations(ground_truths, predictions): - return 0. + return 0.0 elif has_no_annotations(ground_truths, predictions): return None pairs = _get_vector_pairs(ground_truths, predictions, buffer=buffer) return object_pair_miou(pairs, include_subclasses) -def object_pair_miou(pairs: List[Tuple[ObjectAnnotation, ObjectAnnotation, - ScalarMetricValue]], - include_subclasses) -> ScalarMetricValue: +def object_pair_miou( + pairs: List[Tuple[ObjectAnnotation, ObjectAnnotation, ScalarMetricValue]], + include_subclasses, +) -> ScalarMetricValue: pairs.sort(key=lambda triplet: triplet[2], reverse=True) solution_agreements = [] solution_features = set() all_features = set() for prediction, ground_truth, agreement in pairs: all_features.update({id(prediction), id(ground_truth)}) - if id(prediction) not in solution_features and id( - ground_truth) not in solution_features: + if ( + id(prediction) not in solution_features + and id(ground_truth) not in solution_features + ): solution_features.update({id(prediction), id(ground_truth)}) if include_subclasses: - classification_iou = miou(prediction.classifications, - ground_truth.classifications, - include_subclasses=False) - classification_iou = classification_iou if classification_iou is not None else agreement + classification_iou = miou( + prediction.classifications, + ground_truth.classifications, + include_subclasses=False, + ) + classification_iou = ( + classification_iou + if classification_iou is not None + else agreement + ) solution_agreements.append( - (agreement + classification_iou) / 2.) + (agreement + classification_iou) / 2.0 + ) else: solution_agreements.append(agreement) # Add zeros for unmatched Features - solution_agreements.extend([0.0] * - (len(all_features) - len(solution_features))) + solution_agreements.extend( + [0.0] * (len(all_features) - len(solution_features)) + ) return np.mean(solution_agreements) -def mask_miou(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation], - include_subclasses: bool) -> Optional[ScalarMetricValue]: +def mask_miou( + ground_truths: List[ObjectAnnotation], + predictions: List[ObjectAnnotation], + include_subclasses: bool, +) -> Optional[ScalarMetricValue]: """ Computes iou score for all features with the same feature schema id. Calculation includes subclassifications. @@ -133,7 +166,7 @@ def mask_miou(ground_truths: List[ObjectAnnotation], float representing the iou score for the masks """ if has_no_matching_annotations(ground_truths, predictions): - return 0. 
+ return 0.0 elif has_no_annotations(ground_truths, predictions): return None @@ -141,22 +174,26 @@ def mask_miou(ground_truths: List[ObjectAnnotation], pairs = _get_mask_pairs(ground_truths, predictions) return object_pair_miou(pairs, include_subclasses=include_subclasses) - prediction_np = np.max([pred.value.draw(color=1) for pred in predictions], - axis=0) + prediction_np = np.max( + [pred.value.draw(color=1) for pred in predictions], axis=0 + ) ground_truth_np = np.max( [ground_truth.value.draw(color=1) for ground_truth in ground_truths], - axis=0) + axis=0, + ) if prediction_np.shape != ground_truth_np.shape: raise ValueError( "Prediction and mask must have the same shape." - f" Found {prediction_np.shape}/{ground_truth_np.shape}.") + f" Found {prediction_np.shape}/{ground_truth_np.shape}." + ) return _mask_iou(ground_truth_np, prediction_np) def classification_miou( - ground_truths: List[ClassificationAnnotation], - predictions: List[ClassificationAnnotation]) -> ScalarMetricValue: + ground_truths: List[ClassificationAnnotation], + predictions: List[ClassificationAnnotation], +) -> ScalarMetricValue: """ Computes iou score for all features with the same feature schema id. @@ -168,14 +205,15 @@ def classification_miou( """ if len(predictions) != len(ground_truths) != 1: - return 0. + return 0.0 prediction, ground_truth = predictions[0], ground_truths[0] if type(prediction) != type(ground_truth): raise TypeError( "Classification features must be the same type to compute agreement. " - f"Found `{type(prediction)}` and `{type(ground_truth)}`") + f"Found `{type(prediction)}` and `{type(ground_truth)}`" + ) if isinstance(prediction.value, Text): return text_iou(ground_truth.value, prediction.value) @@ -193,7 +231,8 @@ def radio_iou(ground_truth: Radio, prediction: Radio) -> ScalarMetricValue: """ key = get_identifying_key([prediction.answer], [ground_truth.answer]) return float( - getattr(prediction.answer, key) == getattr(ground_truth.answer, key)) + getattr(prediction.answer, key) == getattr(ground_truth.answer, key) + ) def text_iou(ground_truth: Text, prediction: Text) -> ScalarMetricValue: @@ -203,8 +242,9 @@ def text_iou(ground_truth: Text, prediction: Text) -> ScalarMetricValue: return float(prediction.answer == ground_truth.answer) -def checklist_iou(ground_truth: Checklist, - prediction: Checklist) -> ScalarMetricValue: +def checklist_iou( + ground_truth: Checklist, prediction: Checklist +) -> ScalarMetricValue: """ Calculates agreement between ground truth and predicted checklist items """ @@ -212,13 +252,15 @@ def checklist_iou(ground_truth: Checklist, schema_ids_pred = {getattr(answer, key) for answer in prediction.answer} schema_ids_label = {getattr(answer, key) for answer in ground_truth.answer} return float( - len(schema_ids_label & schema_ids_pred) / - len(schema_ids_label | schema_ids_pred)) + len(schema_ids_label & schema_ids_pred) + / len(schema_ids_label | schema_ids_pred) + ) def _get_vector_pairs( - ground_truths: List[ObjectAnnotation], predictions: List[ObjectAnnotation], - buffer: float + ground_truths: List[ObjectAnnotation], + predictions: List[ObjectAnnotation], + buffer: float, ) -> List[Tuple[ObjectAnnotation, ObjectAnnotation, ScalarMetricValue]]: """ # Get iou score for all pairs of ground truths and predictions @@ -226,14 +268,17 @@ def _get_vector_pairs( pairs = [] for ground_truth, prediction in product(ground_truths, predictions): if isinstance(prediction.value, Geometry) and isinstance( - ground_truth.value, Geometry): + ground_truth.value, Geometry + 
): if isinstance(prediction.value, (Line, Point)): - - score = _polygon_iou(prediction.value.shapely.buffer(buffer), - ground_truth.value.shapely.buffer(buffer)) + score = _polygon_iou( + prediction.value.shapely.buffer(buffer), + ground_truth.value.shapely.buffer(buffer), + ) else: - score = _polygon_iou(prediction.value.shapely, - ground_truth.value.shapely) + score = _polygon_iou( + prediction.value.shapely, ground_truth.value.shapely + ) pairs.append((ground_truth, prediction, score)) return pairs @@ -247,9 +292,11 @@ def _get_mask_pairs( pairs = [] for ground_truth, prediction in product(ground_truths, predictions): if isinstance(prediction.value, Mask) and isinstance( - ground_truth.value, Mask): - score = _mask_iou(prediction.value.draw(color=1), - ground_truth.value.draw(color=1)) + ground_truth.value, Mask + ): + score = _mask_iou( + prediction.value.draw(color=1), ground_truth.value.draw(color=1) + ) pairs.append((ground_truth, prediction, score)) return pairs @@ -259,7 +306,7 @@ def _polygon_iou(poly1: Polygon, poly2: Polygon) -> ScalarMetricValue: poly1, poly2 = _ensure_valid_poly(poly1), _ensure_valid_poly(poly2) if poly1.intersects(poly2): return poly1.intersection(poly2).area / poly1.union(poly2).area - return 0. + return 0.0 def _ensure_valid_poly(poly): @@ -286,22 +333,28 @@ def _get_ner_pairs( def _ner_iou(ner1: TextEntity, ner2: TextEntity): """Computes iou between two text entity annotations""" - intersection_start, intersection_end = max(ner1.start, ner2.start), min( - ner1.end, ner2.end) - union_start, union_end = min(ner1.start, - ner2.start), max(ner1.end, ner2.end) - #edge case of only one character in text + intersection_start, intersection_end = ( + max(ner1.start, ner2.start), + min(ner1.end, ner2.end), + ) + union_start, union_end = ( + min(ner1.start, ner2.start), + max(ner1.end, ner2.end), + ) + # edge case of only one character in text if union_start == union_end: return 1 - #if there is no intersection + # if there is no intersection if intersection_start > intersection_end: return 0 return (intersection_end - intersection_start) / (union_end - union_start) -def ner_miou(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation], - include_subclasses: bool) -> Optional[ScalarMetricValue]: +def ner_miou( + ground_truths: List[ObjectAnnotation], + predictions: List[ObjectAnnotation], + include_subclasses: bool, +) -> Optional[ScalarMetricValue]: """ Computes iou score for all features with the same feature schema id. Calculation includes subclassifications. @@ -314,8 +367,8 @@ def ner_miou(ground_truths: List[ObjectAnnotation], If there are no matches then this returns none """ if has_no_matching_annotations(ground_truths, predictions): - return 0. 
+ return 0.0 elif has_no_annotations(ground_truths, predictions): return None pairs = _get_ner_pairs(ground_truths, predictions) - return object_pair_miou(pairs, include_subclasses) \ No newline at end of file + return object_pair_miou(pairs, include_subclasses) diff --git a/libs/labelbox/src/labelbox/data/metrics/iou/iou.py b/libs/labelbox/src/labelbox/data/metrics/iou/iou.py index 357dc5dc9..9b0ce2695 100644 --- a/libs/labelbox/src/labelbox/data/metrics/iou/iou.py +++ b/libs/labelbox/src/labelbox/data/metrics/iou/iou.py @@ -1,19 +1,22 @@ # type: ignore from labelbox.data.annotation_types.metrics.scalar import ScalarMetric from typing import List, Optional, Union -from ...annotation_types import (Label, ObjectAnnotation, - ClassificationAnnotation) +from ...annotation_types import ( + Label, + ObjectAnnotation, + ClassificationAnnotation, +) from ..group import get_feature_pairs from .calculation import feature_miou from .calculation import miou -def miou_metric(ground_truths: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - include_subclasses=False) -> List[ScalarMetric]: +def miou_metric( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + include_subclasses=False, +) -> List[ScalarMetric]: """ Computes miou between two sets of annotations. These annotations should relate to the same data (image/video). @@ -34,11 +37,11 @@ def miou_metric(ground_truths: List[Union[ObjectAnnotation, return [ScalarMetric(metric_name="custom_iou", value=iou)] -def feature_miou_metric(ground_truths: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - include_subclasses=True) -> List[ScalarMetric]: +def feature_miou_metric( + ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], + predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], + include_subclasses=True, +) -> List[ScalarMetric]: """ Computes the miou for each type of class in the list of annotations. These annotations should relate to the same data (image/video). @@ -56,21 +59,24 @@ def feature_miou_metric(ground_truths: List[Union[ObjectAnnotation, annotation_pairs = get_feature_pairs(predictions, ground_truths) metrics = [] for key in annotation_pairs: - - value = feature_miou(annotation_pairs[key][0], annotation_pairs[key][1], - include_subclasses) + value = feature_miou( + annotation_pairs[key][0], + annotation_pairs[key][1], + include_subclasses, + ) if value is None: continue metrics.append( - ScalarMetric(metric_name="custom_iou", - feature_name=key, - value=value)) + ScalarMetric( + metric_name="custom_iou", feature_name=key, value=value + ) + ) return metrics -def data_row_miou(ground_truth: Label, - prediction: Label, - include_subclasses=False) -> Optional[float]: +def data_row_miou( + ground_truth: Label, prediction: Label, include_subclasses=False +) -> Optional[float]: """ This function is no longer supported. Use miou() for raw values or miou_metric() for the metric @@ -84,5 +90,6 @@ def data_row_miou(ground_truth: Label, float indicating the iou score for this data row. 
Returns None if there are no annotations in ground_truth or prediction Labels """ - return miou(ground_truth.annotations, prediction.annotations, - include_subclasses) + return miou( + ground_truth.annotations, prediction.annotations, include_subclasses + ) diff --git a/libs/labelbox/src/labelbox/data/mixins.py b/libs/labelbox/src/labelbox/data/mixins.py index d8bc78de0..4440c8a72 100644 --- a/libs/labelbox/src/labelbox/data/mixins.py +++ b/libs/labelbox/src/labelbox/data/mixins.py @@ -2,7 +2,10 @@ from pydantic import BaseModel, field_validator, model_serializer -from labelbox.exceptions import ConfidenceNotSupportedException, CustomMetricsNotSupportedException +from labelbox.exceptions import ( + ConfidenceNotSupportedException, + CustomMetricsNotSupportedException, +) from warnings import warn @@ -20,11 +23,11 @@ def confidence_valid_float(cls, value): class ConfidenceNotSupportedMixin: - def __new__(cls, *args, **kwargs): if "confidence" in kwargs: raise ConfidenceNotSupportedException( - "Confidence is not supported for this annotation type yet") + "Confidence is not supported for this annotation type yet" + ) return super().__new__(cls) @@ -50,9 +53,9 @@ class CustomMetricsMixin(BaseModel): class CustomMetricsNotSupportedMixin: - def __new__(cls, *args, **kwargs): if "custom_metrics" in kwargs: raise CustomMetricsNotSupportedException( - "Custom metrics is not supported for this annotation type yet") + "Custom metrics is not supported for this annotation type yet" + ) return super().__new__(cls) diff --git a/libs/labelbox/src/labelbox/data/ontology.py b/libs/labelbox/src/labelbox/data/ontology.py index f19208873..4d2e66e95 100644 --- a/libs/labelbox/src/labelbox/data/ontology.py +++ b/libs/labelbox/src/labelbox/data/ontology.py @@ -1,13 +1,23 @@ from typing import Dict, List, Tuple, Union from labelbox.schema import ontology -from .annotation_types import (Text, Checklist, Radio, - ClassificationAnnotation, ObjectAnnotation, Mask, - Point, Line, Polygon, Rectangle, TextEntity) +from .annotation_types import ( + Text, + Checklist, + Radio, + ClassificationAnnotation, + ObjectAnnotation, + Mask, + Point, + Line, + Polygon, + Rectangle, + TextEntity, +) def get_feature_schema_lookup( - ontology_builder: ontology.OntologyBuilder + ontology_builder: ontology.OntologyBuilder, ) -> Tuple[Dict[str, str], Dict[str, str]]: tool_lookup = {} classification_lookup = {} @@ -19,11 +29,13 @@ def flatten_classification(classifications): f"feature_schema_id cannot be None for classification `{classification.name}`." ) if isinstance(classification, ontology.Classification): - classification_lookup[ - classification.name] = classification.feature_schema_id + classification_lookup[classification.name] = ( + classification.feature_schema_id + ) elif isinstance(classification, ontology.Option): - classification_lookup[ - classification.value] = classification.feature_schema_id + classification_lookup[classification.value] = ( + classification.feature_schema_id + ) else: raise TypeError( f"Unexpected type found in ontology. `{type(classification)}`" @@ -33,15 +45,18 @@ def flatten_classification(classifications): for tool in ontology_builder.tools: if tool.feature_schema_id is None: raise ValueError( - f"feature_schema_id cannot be None for tool `{tool.name}`.") + f"feature_schema_id cannot be None for tool `{tool.name}`." 
+ ) tool_lookup[tool.name] = tool.feature_schema_id flatten_classification(tool.classifications) flatten_classification(ontology_builder.classifications) return tool_lookup, classification_lookup -def _get_options(annotation: ClassificationAnnotation, - existing_options: List[ontology.Option]): +def _get_options( + annotation: ClassificationAnnotation, + existing_options: List[ontology.Option], +): if isinstance(annotation.value, Radio): answers = [annotation.value.answer] elif isinstance(annotation.value, Text): @@ -63,7 +78,7 @@ def _get_options(annotation: ClassificationAnnotation, def get_classifications( annotations: List[ClassificationAnnotation], - existing_classifications: List[ontology.Classification] + existing_classifications: List[ontology.Classification], ) -> List[ontology.Classification]: existing_classifications = { classification.name: classification @@ -74,37 +89,45 @@ def get_classifications( classification_feature = existing_classifications.get(annotation.name) if classification_feature: classification_feature.options = _get_options( - annotation, classification_feature.options) + annotation, classification_feature.options + ) elif annotation.name not in existing_classifications: existing_classifications[annotation.name] = ontology.Classification( class_type=classification_mapping(annotation), name=annotation.name, - options=_get_options(annotation, [])) + options=_get_options(annotation, []), + ) return list(existing_classifications.values()) def get_tools( - annotations: List[ObjectAnnotation], - existing_tools: List[ontology.Classification]) -> List[ontology.Tool]: + annotations: List[ObjectAnnotation], + existing_tools: List[ontology.Classification], +) -> List[ontology.Tool]: existing_tools = {tool.name: tool for tool in existing_tools} for annotation in annotations: if annotation.name in existing_tools: # We just want to update classifications existing_tools[ - annotation.name].classifications = get_classifications( - annotation.classifications, - existing_tools[annotation.name].classifications) + annotation.name + ].classifications = get_classifications( + annotation.classifications, + existing_tools[annotation.name].classifications, + ) else: existing_tools[annotation.name] = ontology.Tool( tool=tool_mapping(annotation), name=annotation.name, - classifications=get_classifications(annotation.classifications, - [])) + classifications=get_classifications( + annotation.classifications, [] + ), + ) return list(existing_tools.values()) def tool_mapping( - annotation) -> Union[Mask, Polygon, Point, Rectangle, Line, TextEntity]: + annotation, +) -> Union[Mask, Polygon, Point, Rectangle, Line, TextEntity]: tool_types = ontology.Tool.Type mapping = { Mask: tool_types.SEGMENTATION, @@ -122,8 +145,7 @@ def tool_mapping( return result -def classification_mapping( - annotation) -> Union[Text, Checklist, Radio]: +def classification_mapping(annotation) -> Union[Text, Checklist, Radio]: classification_types = ontology.Classification.Type mapping = { Text: classification_types.TEXT, diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/annotation.py b/libs/labelbox/src/labelbox/data/serialization/coco/annotation.py index a0292e537..e387cb7d9 100644 --- a/libs/labelbox/src/labelbox/data/serialization/coco/annotation.py +++ b/libs/labelbox/src/labelbox/data/serialization/coco/annotation.py @@ -8,7 +8,9 @@ from ...annotation_types.metrics.scalar import ScalarMetric from ...annotation_types.video import VideoMaskAnnotation from ...annotation_types.annotation import 
ObjectAnnotation -from ...annotation_types.classification.classification import ClassificationAnnotation +from ...annotation_types.classification.classification import ( + ClassificationAnnotation, +) import numpy as np @@ -19,8 +21,9 @@ def rle_decoding(rle_arr: List[int], w: int, h: int) -> np.ndarray: indices = [] for idx, cnt in zip(rle_arr[0::2], rle_arr[1::2]): - indices.extend(list(range(idx - 1, - idx + cnt - 1))) # RLE is 1-based index + indices.extend( + list(range(idx - 1, idx + cnt - 1)) + ) # RLE is 1-based index mask = np.zeros(h * w, dtype=np.uint8) mask[indices] = 1 return mask.reshape((w, h)).T @@ -35,16 +38,18 @@ def get_annotation_lookup(annotations): annotation_lookup = defaultdict(list) for annotation in annotations: # Provide a default value of None if the attribute doesn't exist - attribute_value = getattr(annotation, 'image_id', None) or getattr(annotation, 'name', None) + attribute_value = getattr(annotation, "image_id", None) or getattr( + annotation, "name", None + ) annotation_lookup[attribute_value].append(annotation) - return annotation_lookup + return annotation_lookup class SegmentInfo(BaseModel): id: int category_id: int area: Union[float, int] - bbox: Tuple[float, float, float, float] #[x,y,w,h], + bbox: Tuple[float, float, float, float] # [x,y,w,h], iscrowd: int = 0 @@ -62,7 +67,7 @@ class COCOObjectAnnotation(BaseModel): category_id: int segmentation: Union[RLE, List[List[float]]] # [[x1,y1,x2,y2,x3,y3...]] area: float - bbox: Tuple[float, float, float, float] #[x,y,w,h], + bbox: Tuple[float, float, float, float] # [x,y,w,h], iscrowd: int = 0 diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/categories.py b/libs/labelbox/src/labelbox/data/serialization/coco/categories.py index 07ecacb03..60ba30fce 100644 --- a/libs/labelbox/src/labelbox/data/serialization/coco/categories.py +++ b/libs/labelbox/src/labelbox/data/serialization/coco/categories.py @@ -13,4 +13,5 @@ class Categories(BaseModel): def hash_category_name(name: str) -> int: return int.from_bytes( - md5(name.encode('utf-8')).hexdigest().encode('utf-8'), 'little') + md5(name.encode("utf-8")).hexdigest().encode("utf-8"), "little" + ) diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/converter.py b/libs/labelbox/src/labelbox/data/serialization/coco/converter.py index 1f6e8b178..e270b7573 100644 --- a/libs/labelbox/src/labelbox/data/serialization/coco/converter.py +++ b/libs/labelbox/src/labelbox/data/serialization/coco/converter.py @@ -8,8 +8,9 @@ from ...serialization.coco.panoptic_dataset import CocoPanopticDataset -def create_path_if_not_exists(path: Union[Path, str], - ignore_existing_data=False): +def create_path_if_not_exists( + path: Union[Path, str], ignore_existing_data=False +): path = Path(path) if not path.exists(): path.mkdir(parents=True, exist_ok=True) @@ -37,10 +38,12 @@ class COCOConverter: """ @staticmethod - def serialize_instances(labels: LabelCollection, - image_root: Union[Path, str], - ignore_existing_data=False, - max_workers=8) -> Dict[str, Any]: + def serialize_instances( + labels: LabelCollection, + image_root: Union[Path, str], + ignore_existing_data=False, + max_workers=8, + ) -> Dict[str, Any]: """ Convert a Labelbox LabelCollection into an mscoco dataset. This function will only convert masks, polygons, and rectangles. 
@@ -60,20 +63,23 @@ def serialize_instances(labels: LabelCollection, warnings.warn( "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) image_root = create_path_if_not_exists(image_root, ignore_existing_data) - return CocoInstanceDataset.from_common(labels=labels, - image_root=image_root, - max_workers=max_workers).model_dump() + return CocoInstanceDataset.from_common( + labels=labels, image_root=image_root, max_workers=max_workers + ).model_dump() @staticmethod - def serialize_panoptic(labels: LabelCollection, - image_root: Union[Path, str], - mask_root: Union[Path, str], - all_stuff: bool = False, - ignore_existing_data=False, - max_workers: int = 8) -> Dict[str, Any]: + def serialize_panoptic( + labels: LabelCollection, + image_root: Union[Path, str], + mask_root: Union[Path, str], + all_stuff: bool = False, + ignore_existing_data=False, + max_workers: int = 8, + ) -> Dict[str, Any]: """ Convert a Labelbox LabelCollection into an mscoco dataset. This function will only convert masks, polygons, and rectangles. @@ -96,20 +102,25 @@ def serialize_panoptic(labels: LabelCollection, warnings.warn( "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) image_root = create_path_if_not_exists(image_root, ignore_existing_data) mask_root = create_path_if_not_exists(mask_root, ignore_existing_data) - return CocoPanopticDataset.from_common(labels=labels, - image_root=image_root, - mask_root=mask_root, - all_stuff=all_stuff, - max_workers=max_workers).model_dump() + return CocoPanopticDataset.from_common( + labels=labels, + image_root=image_root, + mask_root=mask_root, + all_stuff=all_stuff, + max_workers=max_workers, + ).model_dump() @staticmethod - def deserialize_panoptic(json_data: Dict[str, Any], image_root: Union[Path, - str], - mask_root: Union[Path, str]) -> LabelGenerator: + def deserialize_panoptic( + json_data: Dict[str, Any], + image_root: Union[Path, str], + mask_root: Union[Path, str], + ) -> LabelGenerator: """ Convert coco panoptic data into the labelbox format (as a LabelGenerator). @@ -124,17 +135,19 @@ def deserialize_panoptic(json_data: Dict[str, Any], image_root: Union[Path, warnings.warn( "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) - image_root = validate_path(image_root, 'image_root') - mask_root = validate_path(mask_root, 'mask_root') + image_root = validate_path(image_root, "image_root") + mask_root = validate_path(mask_root, "mask_root") objs = CocoPanopticDataset(**json_data) gen = objs.to_common(image_root, mask_root) return LabelGenerator(data=gen) @staticmethod - def deserialize_instances(json_data: Dict[str, Any], - image_root: Path) -> LabelGenerator: + def deserialize_instances( + json_data: Dict[str, Any], image_root: Path + ) -> LabelGenerator: """ Convert coco object data into the labelbox format (as a LabelGenerator). 
@@ -148,9 +161,10 @@ def deserialize_instances(json_data: Dict[str, Any], warnings.warn( "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) - image_root = validate_path(image_root, 'image_root') + image_root = validate_path(image_root, "image_root") objs = CocoInstanceDataset(**json_data) gen = objs.to_common(image_root) return LabelGenerator(data=gen) diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/image.py b/libs/labelbox/src/labelbox/data/serialization/coco/image.py index 71029b936..cef173377 100644 --- a/libs/labelbox/src/labelbox/data/serialization/coco/image.py +++ b/libs/labelbox/src/labelbox/data/serialization/coco/image.py @@ -47,6 +47,6 @@ def id_to_rgb(id: int) -> Tuple[int, int, int]: def rgb_to_id(red: int, green: int, blue: int) -> int: id = blue * 256 * 256 - id += (green * 256) + id += green * 256 id += red return id diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/instance_dataset.py b/libs/labelbox/src/labelbox/data/serialization/coco/instance_dataset.py index 7cade81a1..5241e596f 100644 --- a/libs/labelbox/src/labelbox/data/serialization/coco/instance_dataset.py +++ b/libs/labelbox/src/labelbox/data/serialization/coco/instance_dataset.py @@ -7,17 +7,34 @@ import numpy as np from tqdm import tqdm -from ...annotation_types import ImageData, MaskData, Mask, ObjectAnnotation, Label, Polygon, Point, Rectangle +from ...annotation_types import ( + ImageData, + MaskData, + Mask, + ObjectAnnotation, + Label, + Polygon, + Point, + Rectangle, +) from ...annotation_types.collection import LabelCollection from .categories import Categories, hash_category_name -from .annotation import COCOObjectAnnotation, RLE, get_annotation_lookup, rle_decoding +from .annotation import ( + COCOObjectAnnotation, + RLE, + get_annotation_lookup, + rle_decoding, +) from .image import CocoImage, get_image, get_image_id from pydantic import BaseModel def mask_to_coco_object_annotation( - annotation: ObjectAnnotation, annot_idx: int, image_id: int, - category_id: int) -> Optional[COCOObjectAnnotation]: + annotation: ObjectAnnotation, + annot_idx: int, + image_id: int, + category_id: int, +) -> Optional[COCOObjectAnnotation]: # This is going to fill any holes into the multipolygon # If you need to support holes use the panoptic data format shapely = annotation.value.shapely.simplify(1).buffer(0) @@ -38,12 +55,16 @@ def mask_to_coco_object_annotation( ], area=area, bbox=[xmin, ymin, xmax - xmin, ymax - ymin], - iscrowd=0) + iscrowd=0, + ) -def vector_to_coco_object_annotation(annotation: ObjectAnnotation, - annot_idx: int, image_id: int, - category_id: int) -> COCOObjectAnnotation: +def vector_to_coco_object_annotation( + annotation: ObjectAnnotation, + annot_idx: int, + image_id: int, + category_id: int, +) -> COCOObjectAnnotation: shapely = annotation.value.shapely xmin, ymin, xmax, ymax = shapely.bounds segmentation = [] @@ -52,61 +73,83 @@ def vector_to_coco_object_annotation(annotation: ObjectAnnotation, segmentation.extend([point.x, point.y]) else: box = annotation.value - segmentation.extend([ - box.start.x, box.start.y, box.end.x, box.start.y, box.end.x, - box.end.y, box.start.x, box.end.y - ]) - - return COCOObjectAnnotation(id=annot_idx, - image_id=image_id, - category_id=category_id, - segmentation=[segmentation], - area=shapely.area, - bbox=[xmin, ymin, xmax - xmin, ymax - ymin], - iscrowd=0) - - -def rle_to_common(class_annotations: COCOObjectAnnotation, 
- class_name: str) -> ObjectAnnotation: - mask = rle_decoding(class_annotations.segmentation.counts, - *class_annotations.segmentation.size[::-1]) - return ObjectAnnotation(name=class_name, - value=Mask(mask=MaskData.from_2D_arr(mask), - color=[1, 1, 1])) - - -def segmentations_to_common(class_annotations: COCOObjectAnnotation, - class_name: str) -> List[ObjectAnnotation]: + segmentation.extend( + [ + box.start.x, + box.start.y, + box.end.x, + box.start.y, + box.end.x, + box.end.y, + box.start.x, + box.end.y, + ] + ) + + return COCOObjectAnnotation( + id=annot_idx, + image_id=image_id, + category_id=category_id, + segmentation=[segmentation], + area=shapely.area, + bbox=[xmin, ymin, xmax - xmin, ymax - ymin], + iscrowd=0, + ) + + +def rle_to_common( + class_annotations: COCOObjectAnnotation, class_name: str +) -> ObjectAnnotation: + mask = rle_decoding( + class_annotations.segmentation.counts, + *class_annotations.segmentation.size[::-1], + ) + return ObjectAnnotation( + name=class_name, + value=Mask(mask=MaskData.from_2D_arr(mask), color=[1, 1, 1]), + ) + + +def segmentations_to_common( + class_annotations: COCOObjectAnnotation, class_name: str +) -> List[ObjectAnnotation]: # Technically it is polygons. But the key in coco is called segmentations.. annotations = [] for points in class_annotations.segmentation: annotations.append( - ObjectAnnotation(name=class_name, - value=Polygon(points=[ - Point(x=points[i], y=points[i + 1]) - for i in range(0, len(points), 2) - ]))) + ObjectAnnotation( + name=class_name, + value=Polygon( + points=[ + Point(x=points[i], y=points[i + 1]) + for i in range(0, len(points), 2) + ] + ), + ) + ) return annotations def object_annotation_to_coco( - annotation: ObjectAnnotation, annot_idx: int, image_id: int, - category_id: int) -> Optional[COCOObjectAnnotation]: + annotation: ObjectAnnotation, + annot_idx: int, + image_id: int, + category_id: int, +) -> Optional[COCOObjectAnnotation]: if isinstance(annotation.value, Mask): - return mask_to_coco_object_annotation(annotation, annot_idx, image_id, - category_id) + return mask_to_coco_object_annotation( + annotation, annot_idx, image_id, category_id + ) elif isinstance(annotation.value, (Polygon, Rectangle)): - return vector_to_coco_object_annotation(annotation, annot_idx, image_id, - category_id) + return vector_to_coco_object_annotation( + annotation, annot_idx, image_id, category_id + ) else: return None def process_label( - label: Label, - idx: int, - image_root: str, - max_annotations_per_image=10000 + label: Label, idx: int, image_root: str, max_annotations_per_image=10000 ) -> Tuple[np.ndarray, List[COCOObjectAnnotation], Dict[str, str]]: annot_idx = idx * max_annotations_per_image image_id = get_image_id(label, idx) @@ -117,9 +160,11 @@ def process_label( for class_name in annotation_lookup: for annotation in annotation_lookup[class_name]: category_id = categories.get(annotation.name) or hash_category_name( - annotation.name) - coco_annotation = object_annotation_to_coco(annotation, annot_idx, - image_id, category_id) + annotation.name + ) + coco_annotation = object_annotation_to_coco( + annotation, annot_idx, image_id, category_id + ) if coco_annotation is not None: coco_annotations.append(coco_annotation) if annotation.name not in categories: @@ -136,10 +181,9 @@ class CocoInstanceDataset(BaseModel): categories: List[Categories] @classmethod - def from_common(cls, - labels: LabelCollection, - image_root: Path, - max_workers=8): + def from_common( + cls, labels: LabelCollection, image_root: Path, 
max_workers=8 + ): all_coco_annotations = [] categories = {} images = [] @@ -156,7 +200,6 @@ def from_common(cls, future.result() for future in tqdm(as_completed(futures)) ] else: - results = [ process_label(label, idx, image_root) for idx, label in enumerate(labels) @@ -172,18 +215,23 @@ def from_common(cls, for idx, category_id in enumerate(coco_categories.values()) } categories = [ - Categories(id=category_mapping[idx], - name=name, - supercategory='all', - isthing=1) for name, idx in coco_categories.items() + Categories( + id=category_mapping[idx], + name=name, + supercategory="all", + isthing=1, + ) + for name, idx in coco_categories.items() ] for annot in all_coco_annotations: annot.category_id = category_mapping[annot.category_id] - return CocoInstanceDataset(info={'image_root': image_root}, - images=images, - annotations=all_coco_annotations, - categories=categories) + return CocoInstanceDataset( + info={"image_root": image_root}, + images=images, + annotations=all_coco_annotations, + categories=categories, + ) def to_common(self, image_root): category_lookup = { @@ -204,11 +252,15 @@ def to_common(self, image_root): if isinstance(class_annotations.segmentation, RLE): annotations.append( rle_to_common( - class_annotations, category_lookup[ - class_annotations.category_id].name)) + class_annotations, + category_lookup[class_annotations.category_id].name, + ) + ) elif isinstance(class_annotations.segmentation, list): annotations.extend( segmentations_to_common( - class_annotations, category_lookup[ - class_annotations.category_id].name)) + class_annotations, + category_lookup[class_annotations.category_id].name, + ) + ) yield Label(data=data, annotations=annotations) diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/panoptic_dataset.py b/libs/labelbox/src/labelbox/data/serialization/coco/panoptic_dataset.py index 4d6b9e2ef..cbb410548 100644 --- a/libs/labelbox/src/labelbox/data/serialization/coco/panoptic_dataset.py +++ b/libs/labelbox/src/labelbox/data/serialization/coco/panoptic_dataset.py @@ -18,29 +18,36 @@ from pydantic import BaseModel -def vector_to_coco_segment_info(canvas: np.ndarray, - annotation: ObjectAnnotation, - annotation_idx: int, image: CocoImage, - category_id: int): - +def vector_to_coco_segment_info( + canvas: np.ndarray, + annotation: ObjectAnnotation, + annotation_idx: int, + image: CocoImage, + category_id: int, +): shapely = annotation.value.shapely if shapely.is_empty: return xmin, ymin, xmax, ymax = shapely.bounds - canvas = annotation.value.draw(height=image.height, - width=image.width, - canvas=canvas, - color=id_to_rgb(annotation_idx)) - - return SegmentInfo(id=annotation_idx, - category_id=category_id, - area=shapely.area, - bbox=[xmin, ymin, xmax - xmin, ymax - ymin]), canvas - - -def mask_to_coco_segment_info(canvas: np.ndarray, annotation, - annotation_idx: int, category_id): + canvas = annotation.value.draw( + height=image.height, + width=image.width, + canvas=canvas, + color=id_to_rgb(annotation_idx), + ) + + return SegmentInfo( + id=annotation_idx, + category_id=category_id, + area=shapely.area, + bbox=[xmin, ymin, xmax - xmin, ymax - ymin], + ), canvas + + +def mask_to_coco_segment_info( + canvas: np.ndarray, annotation, annotation_idx: int, category_id +): color = id_to_rgb(annotation_idx) mask = annotation.value.draw(color=color) shapely = annotation.value.shapely @@ -49,17 +56,17 @@ def mask_to_coco_segment_info(canvas: np.ndarray, annotation, xmin, ymin, xmax, ymax = shapely.bounds canvas = np.where(canvas == (0, 0, 0), mask, 
canvas) - return SegmentInfo(id=annotation_idx, - category_id=category_id, - area=shapely.area, - bbox=[xmin, ymin, xmax - xmin, ymax - ymin]), canvas + return SegmentInfo( + id=annotation_idx, + category_id=category_id, + area=shapely.area, + bbox=[xmin, ymin, xmax - xmin, ymax - ymin], + ), canvas -def process_label(label: Label, - idx: Union[int, str], - image_root, - mask_root, - all_stuff=False): +def process_label( + label: Label, idx: Union[int, str], image_root, mask_root, all_stuff=False +): """ Masks become stuff Polygon and rectangle become thing @@ -78,8 +85,11 @@ def process_label(label: Label, categories[annotation.name] = hash_category_name(annotation.name) if isinstance(annotation.value, Mask): coco_segment_info = mask_to_coco_segment_info( - canvas, annotation, class_idx + 1, - categories[annotation.name]) + canvas, + annotation, + class_idx + 1, + categories[annotation.name], + ) if coco_segment_info is None: # Filter out empty masks @@ -96,7 +106,8 @@ def process_label(label: Label, annotation_idx=(class_idx if all_stuff else annotation_idx) + 1, image=image, - category_id=categories[annotation.name]) + category_id=categories[annotation.name], + ) if coco_vector_info is None: # Filter out empty annotations @@ -106,13 +117,19 @@ def process_label(label: Label, segments.append(segment) is_thing[annotation.name] = 1 - int(all_stuff) - mask_file = str(image.file_name).replace('.jpg', '.png') + mask_file = str(image.file_name).replace(".jpg", ".png") mask_file = Path(mask_root, mask_file) Image.fromarray(canvas.astype(np.uint8)).save(mask_file) - return image, PanopticAnnotation( - image_id=image_id, - file_name=Path(mask_file.name), - segments_info=segments), categories, is_thing + return ( + image, + PanopticAnnotation( + image_id=image_id, + file_name=Path(mask_file.name), + segments_info=segments, + ), + categories, + is_thing, + ) class CocoPanopticDataset(BaseModel): @@ -122,12 +139,14 @@ class CocoPanopticDataset(BaseModel): categories: List[Categories] @classmethod - def from_common(cls, - labels: LabelCollection, - image_root, - mask_root, - all_stuff, - max_workers=8): + def from_common( + cls, + labels: LabelCollection, + image_root, + mask_root, + all_stuff, + max_workers=8, + ): all_coco_annotations = [] coco_categories = {} coco_things = {} @@ -136,8 +155,15 @@ def from_common(cls, if max_workers: with ProcessPoolExecutor(max_workers=max_workers) as exc: futures = [ - exc.submit(process_label, label, idx, image_root, mask_root, - all_stuff) for idx, label in enumerate(labels) + exc.submit( + process_label, + label, + idx, + image_root, + mask_root, + all_stuff, + ) + for idx, label in enumerate(labels) ] results = [ future.result() for future in tqdm(as_completed(futures)) @@ -159,10 +185,12 @@ def from_common(cls, for idx, category_id in enumerate(coco_categories.values()) } categories = [ - Categories(id=category_mapping[idx], - name=name, - supercategory='all', - isthing=coco_things.get(name, 1)) + Categories( + id=category_mapping[idx], + name=name, + supercategory="all", + isthing=coco_things.get(name, 1), + ) for name, idx in coco_categories.items() ] @@ -170,13 +198,12 @@ def from_common(cls, for segment in annot.segments_info: segment.category_id = category_mapping[segment.category_id] - return CocoPanopticDataset(info={ - 'image_root': image_root, - 'mask_root': mask_root - }, - images=images, - annotations=all_coco_annotations, - categories=categories) + return CocoPanopticDataset( + info={"image_root": image_root, "mask_root": mask_root}, + 
images=images, + annotations=all_coco_annotations, + categories=categories, + ) def to_common(self, image_root: Path, mask_root: Path): category_lookup = { @@ -194,20 +221,22 @@ def to_common(self, image_root: Path, mask_root: Path): raise ValueError( f"Cannot find file {im_path}. Make sure `image_root` is set properly" ) - if not str(annotation.file_name).endswith('.png'): + if not str(annotation.file_name).endswith(".png"): raise ValueError( f"COCO masks must be stored as png files and their extension must be `.png`. Found {annotation.file_name}" ) mask = MaskData( - file_path=str(Path(mask_root, annotation.file_name))) + file_path=str(Path(mask_root, annotation.file_name)) + ) for segmentation in annotation.segments_info: category = category_lookup[segmentation.category_id] annotations.append( - ObjectAnnotation(name=category.name, - value=Mask(mask=mask, - color=id_to_rgb( - segmentation.id)))) + ObjectAnnotation( + name=category.name, + value=Mask(mask=mask, color=id_to_rgb(segmentation.id)), + ) + ) data = ImageData(file_path=str(im_path)) yield Label(data=data, annotations=annotations) del annotation_lookup[image.id] diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/path.py b/libs/labelbox/src/labelbox/data/serialization/coco/path.py index 8f6786655..c3be84f31 100644 --- a/libs/labelbox/src/labelbox/data/serialization/coco/path.py +++ b/libs/labelbox/src/labelbox/data/serialization/coco/path.py @@ -1,8 +1,8 @@ from pathlib import Path from pydantic import BaseModel, model_serializer -class PathSerializerMixin(BaseModel): +class PathSerializerMixin(BaseModel): @model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py index 602fa7628..8770222b9 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py @@ -9,20 +9,20 @@ subclass_registry = {} + class _SubclassRegistryBase(BaseModel): - model_config = ConfigDict(extra="allow") - + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if cls.__name__ != "NDAnnotation": with threading.Lock(): - subclass_registry[cls.__name__] = cls + subclass_registry[cls.__name__] = cls + class DataRow(_CamelCaseMixin): id: Optional[str] = None global_key: Optional[str] = None - @model_validator(mode="after") def must_set_one(self): @@ -45,6 +45,8 @@ class NDAnnotation(NDJsonBase): @model_validator(mode="after") def must_set_one(self): - if (not hasattr(self, "schema_id") or self.schema_id is None) and (not hasattr(self, "name") or self.name is None): + if (not hasattr(self, "schema_id") or self.schema_id is None) and ( + not hasattr(self, "name") or self.name is None + ): raise ValueError("Schema id or name are not set. 
Set either one.") return self diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py index e655e9f36..f4bc7e528 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py @@ -1,15 +1,33 @@ from typing import Any, Dict, List, Union, Optional -from labelbox.data.mixins import ConfidenceMixin, CustomMetric, CustomMetricsMixin +from labelbox.data.mixins import ( + ConfidenceMixin, + CustomMetric, + CustomMetricsMixin, +) from labelbox.data.serialization.ndjson.base import DataRow, NDAnnotation from ...annotation_types.annotation import ClassificationAnnotation from ...annotation_types.video import VideoClassificationAnnotation -from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation, PromptText -from ...annotation_types.classification.classification import ClassificationAnswer, Text, Checklist, Radio +from ...annotation_types.llm_prompt_response.prompt import ( + PromptClassificationAnnotation, + PromptText, +) +from ...annotation_types.classification.classification import ( + ClassificationAnswer, + Text, + Checklist, + Radio, +) from ...annotation_types.types import Cuid from ...annotation_types.data import TextData, VideoData, ImageData -from pydantic import model_validator, Field, BaseModel, ConfigDict, model_serializer +from pydantic import ( + model_validator, + Field, + BaseModel, + ConfigDict, + model_serializer, +) from pydantic.alias_generators import to_camel from .base import _SubclassRegistryBase @@ -17,24 +35,26 @@ class NDAnswer(ConfidenceMixin, CustomMetricsMixin): name: Optional[str] = None schema_id: Optional[Cuid] = None - classifications: Optional[List['NDSubclassificationType']] = None - model_config = ConfigDict(populate_by_name = True, alias_generator = to_camel) + classifications: Optional[List["NDSubclassificationType"]] = None + model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel) @model_validator(mode="after") def must_set_one(self): - if (not hasattr(self, "schema_id") or self.schema_id is None) and (not hasattr(self, "name") or self.name is None): + if (not hasattr(self, "schema_id") or self.schema_id is None) and ( + not hasattr(self, "name") or self.name is None + ): raise ValueError("Schema id or name are not set. Set either one.") return self @model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) - if 'name' in res and res['name'] is None: - res.pop('name') - if 'schemaId' in res and res['schemaId'] is None: - res.pop('schemaId') + if "name" in res and res["name"] is None: + res.pop("name") + if "schemaId" in res and res["schemaId"] is None: + res.pop("schemaId") if self.classifications: - res['classifications'] = [ + res["classifications"] = [ c.model_dump(exclude_none=True) for c in self.classifications ] return res @@ -54,7 +74,7 @@ def serialize_model(self, handler): res = handler(self) # This means these are no video frames .. 
if self.frames is None: - res.pop('frames') + res.pop("frames") return res @@ -62,13 +82,16 @@ class NDTextSubclass(NDAnswer): answer: str def to_common(self) -> Text: - return Text(answer=self.answer, - confidence=self.confidence, - custom_metrics=self.custom_metrics) + return Text( + answer=self.answer, + confidence=self.confidence, + custom_metrics=self.custom_metrics, + ) @classmethod - def from_common(cls, text: Text, name: str, - feature_schema_id: Cuid) -> "NDTextSubclass": + def from_common( + cls, text: Text, name: str, feature_schema_id: Cuid + ) -> "NDTextSubclass": return cls( answer=text.answer, name=name, @@ -79,41 +102,56 @@ def from_common(cls, text: Text, name: str, class NDChecklistSubclass(NDAnswer): - answer: List[NDAnswer] = Field(..., validation_alias='answers') + answer: List[NDAnswer] = Field(..., validation_alias="answers") def to_common(self) -> Checklist: - - return Checklist(answer=[ - ClassificationAnswer(name=answer.name, - feature_schema_id=answer.schema_id, - confidence=answer.confidence, - classifications=[ - NDSubclassification.to_common(annot) - for annot in answer.classifications - ] if answer.classifications else None, - custom_metrics=answer.custom_metrics) - for answer in self.answer - ]) + return Checklist( + answer=[ + ClassificationAnswer( + name=answer.name, + feature_schema_id=answer.schema_id, + confidence=answer.confidence, + classifications=[ + NDSubclassification.to_common(annot) + for annot in answer.classifications + ] + if answer.classifications + else None, + custom_metrics=answer.custom_metrics, + ) + for answer in self.answer + ] + ) @classmethod - def from_common(cls, checklist: Checklist, name: str, - feature_schema_id: Cuid) -> "NDChecklistSubclass": - return cls(answer=[ - NDAnswer(name=answer.name, - schema_id=answer.feature_schema_id, - confidence=answer.confidence, - classifications=[NDSubclassification.from_common(annot) for annot in answer.classifications] if answer.classifications else None, - custom_metrics=answer.custom_metrics) - for answer in checklist.answer - ], - name=name, - schema_id=feature_schema_id) + def from_common( + cls, checklist: Checklist, name: str, feature_schema_id: Cuid + ) -> "NDChecklistSubclass": + return cls( + answer=[ + NDAnswer( + name=answer.name, + schema_id=answer.feature_schema_id, + confidence=answer.confidence, + classifications=[ + NDSubclassification.from_common(annot) + for annot in answer.classifications + ] + if answer.classifications + else None, + custom_metrics=answer.custom_metrics, + ) + for answer in checklist.answer + ], + name=name, + schema_id=feature_schema_id, + ) @model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) - if 'answers' in res: - res['answer'] = res['answers'] + if "answers" in res: + res["answer"] = res["answers"] del res["answers"] return res @@ -122,42 +160,57 @@ class NDRadioSubclass(NDAnswer): answer: NDAnswer def to_common(self) -> Radio: - return Radio(answer=ClassificationAnswer( - name=self.answer.name, - feature_schema_id=self.answer.schema_id, - confidence=self.answer.confidence, - classifications=[ - NDSubclassification.to_common(annot) - for annot in self.answer.classifications - ] if self.answer.classifications else None, - custom_metrics=self.answer.custom_metrics)) + return Radio( + answer=ClassificationAnswer( + name=self.answer.name, + feature_schema_id=self.answer.schema_id, + confidence=self.answer.confidence, + classifications=[ + NDSubclassification.to_common(annot) + for annot in self.answer.classifications + 
] + if self.answer.classifications + else None, + custom_metrics=self.answer.custom_metrics, + ) + ) @classmethod - def from_common(cls, radio: Radio, name: str, - feature_schema_id: Cuid) -> "NDRadioSubclass": - return cls(answer=NDAnswer(name=radio.answer.name, - schema_id=radio.answer.feature_schema_id, - confidence=radio.answer.confidence, - classifications=[ - NDSubclassification.from_common(annot) - for annot in radio.answer.classifications - ] if radio.answer.classifications else None, - custom_metrics=radio.answer.custom_metrics), - name=name, - schema_id=feature_schema_id) + def from_common( + cls, radio: Radio, name: str, feature_schema_id: Cuid + ) -> "NDRadioSubclass": + return cls( + answer=NDAnswer( + name=radio.answer.name, + schema_id=radio.answer.feature_schema_id, + confidence=radio.answer.confidence, + classifications=[ + NDSubclassification.from_common(annot) + for annot in radio.answer.classifications + ] + if radio.answer.classifications + else None, + custom_metrics=radio.answer.custom_metrics, + ), + name=name, + schema_id=feature_schema_id, + ) class NDPromptTextSubclass(NDAnswer): answer: str def to_common(self) -> PromptText: - return PromptText(answer=self.answer, - confidence=self.confidence, - custom_metrics=self.custom_metrics) + return PromptText( + answer=self.answer, + confidence=self.confidence, + custom_metrics=self.custom_metrics, + ) @classmethod - def from_common(cls, prompt_text: PromptText, name: str, - feature_schema_id: Cuid) -> "NDPromptTextSubclass": + def from_common( + cls, prompt_text: PromptText, name: str, feature_schema_id: Cuid + ) -> "NDPromptTextSubclass": return cls( answer=prompt_text.answer, name=name, @@ -171,17 +224,18 @@ def from_common(cls, prompt_text: PromptText, name: str, class NDText(NDAnnotation, NDTextSubclass, _SubclassRegistryBase): - @classmethod - def from_common(cls, - uuid: str, - text: Text, - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[TextData, ImageData], - message_id: str, - confidence: Optional[float] = None) -> "NDText": + def from_common( + cls, + uuid: str, + text: Text, + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[TextData, ImageData], + message_id: str, + confidence: Optional[float] = None, + ) -> "NDText": return cls( answer=text.answer, data_row=DataRow(id=data.uid, global_key=data.global_key), @@ -194,8 +248,9 @@ def from_common(cls, ) -class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported, _SubclassRegistryBase): - +class NDChecklist( + NDAnnotation, NDChecklistSubclass, VideoSupported, _SubclassRegistryBase +): @model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) @@ -205,40 +260,46 @@ def serialize_model(self, handler): @classmethod def from_common( - cls, - uuid: str, - checklist: Checklist, - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[VideoData, TextData, ImageData], - message_id: str, - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None + cls, + uuid: str, + checklist: Checklist, + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[VideoData, TextData, ImageData], + message_id: str, + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDChecklist": + return cls( + answer=[ + NDAnswer( + name=answer.name, + schema_id=answer.feature_schema_id, + confidence=answer.confidence, + classifications=[ + NDSubclassification.from_common(annot) + for 
annot in answer.classifications + ] + if answer.classifications + else None, + custom_metrics=answer.custom_metrics, + ) + for answer in checklist.answer + ], + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + frames=extra.get("frames"), + message_id=message_id, + confidence=confidence, + ) - return cls(answer=[ - NDAnswer(name=answer.name, - schema_id=answer.feature_schema_id, - confidence=answer.confidence, - classifications=[ - NDSubclassification.from_common(annot) - for annot in answer.classifications - ] if answer.classifications else None, - custom_metrics=answer.custom_metrics) - for answer in checklist.answer - ], - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - frames=extra.get('frames'), - message_id=message_id, - confidence=confidence) - - -class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported, _SubclassRegistryBase): +class NDRadio( + NDAnnotation, NDRadioSubclass, VideoSupported, _SubclassRegistryBase +): @classmethod def from_common( cls, @@ -251,32 +312,37 @@ def from_common( message_id: str, confidence: Optional[float] = None, ) -> "NDRadio": - return cls(answer=NDAnswer(name=radio.answer.name, - schema_id=radio.answer.feature_schema_id, - confidence=radio.answer.confidence, - classifications=[ - NDSubclassification.from_common(annot) - for annot in radio.answer.classifications - ] if radio.answer.classifications else None, - custom_metrics=radio.answer.custom_metrics), - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - frames=extra.get('frames'), - message_id=message_id, - confidence=confidence) - + return cls( + answer=NDAnswer( + name=radio.answer.name, + schema_id=radio.answer.feature_schema_id, + confidence=radio.answer.confidence, + classifications=[ + NDSubclassification.from_common(annot) + for annot in radio.answer.classifications + ] + if radio.answer.classifications + else None, + custom_metrics=radio.answer.custom_metrics, + ), + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + frames=extra.get("frames"), + message_id=message_id, + confidence=confidence, + ) + @model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) if "classifications" in res and res["classifications"] == []: del res["classifications"] return res - - + + class NDPromptText(NDAnnotation, NDPromptTextSubclass, _SubclassRegistryBase): - @classmethod def from_common( cls, @@ -285,7 +351,7 @@ def from_common( name, data: Dict, feature_schema_id: Cuid, - confidence: Optional[float] = None + confidence: Optional[float] = None, ) -> "NDPromptText": return cls( answer=text.answer, @@ -294,11 +360,11 @@ def from_common( schema_id=feature_schema_id, uuid=uuid, confidence=text.confidence, - custom_metrics=text.custom_metrics) + custom_metrics=text.custom_metrics, + ) class NDSubclassification: - @classmethod def from_common( cls, annotation: ClassificationAnnotation @@ -308,19 +374,23 @@ def from_common( raise TypeError( f"Unable to convert object to MAL format. 
`{type(annotation.value)}`" ) - return classify_obj.from_common(annotation.value, annotation.name, - annotation.feature_schema_id) + return classify_obj.from_common( + annotation.value, annotation.name, annotation.feature_schema_id + ) @staticmethod def to_common( - annotation: "NDClassificationType") -> ClassificationAnnotation: - return ClassificationAnnotation(value=annotation.to_common(), - name=annotation.name, - feature_schema_id=annotation.schema_id) + annotation: "NDClassificationType", + ) -> ClassificationAnnotation: + return ClassificationAnnotation( + value=annotation.to_common(), + name=annotation.name, + feature_schema_id=annotation.schema_id, + ) @staticmethod def lookup_subclassification( - annotation: ClassificationAnnotation + annotation: ClassificationAnnotation, ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: return { Text: NDTextSubclass, @@ -330,69 +400,76 @@ def lookup_subclassification( class NDClassification: - @staticmethod def to_common( - annotation: "NDClassificationType" + annotation: "NDClassificationType", ) -> Union[ClassificationAnnotation, VideoClassificationAnnotation]: common = ClassificationAnnotation( value=annotation.to_common(), name=annotation.name, feature_schema_id=annotation.schema_id, - extra={'uuid': annotation.uuid}, + extra={"uuid": annotation.uuid}, message_id=annotation.message_id, confidence=annotation.confidence, ) - if getattr(annotation, 'frames', None) is None: + if getattr(annotation, "frames", None) is None: return [common] results = [] for frame in annotation.frames: for idx in range(frame.start, frame.end + 1, 1): results.append( - VideoClassificationAnnotation(frame=idx, **common.model_dump(exclude_none=True))) + VideoClassificationAnnotation( + frame=idx, **common.model_dump(exclude_none=True) + ) + ) return results @classmethod def from_common( - cls, annotation: Union[ClassificationAnnotation, - VideoClassificationAnnotation], - data: Union[VideoData, TextData, ImageData] + cls, + annotation: Union[ + ClassificationAnnotation, VideoClassificationAnnotation + ], + data: Union[VideoData, TextData, ImageData], ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: classify_obj = cls.lookup_classification(annotation) if classify_obj is None: raise TypeError( f"Unable to convert object to MAL format. 
`{type(annotation.value)}`" ) - return classify_obj.from_common(str(annotation._uuid), annotation.value, - annotation.name, - annotation.feature_schema_id, - annotation.extra, data, - annotation.message_id, - annotation.confidence) + return classify_obj.from_common( + str(annotation._uuid), + annotation.value, + annotation.name, + annotation.feature_schema_id, + annotation.extra, + data, + annotation.message_id, + annotation.confidence, + ) @staticmethod def lookup_classification( - annotation: Union[ClassificationAnnotation, - VideoClassificationAnnotation] + annotation: Union[ + ClassificationAnnotation, VideoClassificationAnnotation + ], ) -> Union[NDText, NDChecklist, NDRadio]: - return { - Text: NDText, - Checklist: NDChecklist, - Radio: NDRadio - }.get(type(annotation.value)) + return {Text: NDText, Checklist: NDChecklist, Radio: NDRadio}.get( + type(annotation.value) + ) -class NDPromptClassification: +class NDPromptClassification: @staticmethod def to_common( - annotation: "NDPromptClassificationType" + annotation: "NDPromptClassificationType", ) -> Union[PromptClassificationAnnotation]: common = PromptClassificationAnnotation( value=annotation, name=annotation.name, feature_schema_id=annotation.schema_id, - extra={'uuid': annotation.uuid}, + extra={"uuid": annotation.uuid}, confidence=annotation.confidence, ) @@ -400,20 +477,25 @@ def to_common( @classmethod def from_common( - cls, annotation: Union[PromptClassificationAnnotation], - data: Union[VideoData, TextData, ImageData] + cls, + annotation: Union[PromptClassificationAnnotation], + data: Union[VideoData, TextData, ImageData], ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: - return NDPromptText.from_common(str(annotation._uuid), annotation.value, - annotation.name, - data, - annotation.feature_schema_id, - annotation.confidence) + return NDPromptText.from_common( + str(annotation._uuid), + annotation.value, + annotation.name, + data, + annotation.feature_schema_id, + annotation.confidence, + ) # Make sure to keep NDChecklistSubclass prior to NDRadioSubclass in the list, # otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used -NDSubclassificationType = Union[NDChecklistSubclass, NDRadioSubclass, - NDTextSubclass] +NDSubclassificationType = Union[ + NDChecklistSubclass, NDRadioSubclass, NDTextSubclass +] NDAnswer.model_rebuild() NDChecklistSubclass.model_rebuild() @@ -427,4 +509,4 @@ def from_common( # Make sure to keep NDChecklist prior to NDRadio in the list, # otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used NDClassificationType = Union[NDChecklist, NDRadio, NDText] -NDPromptClassificationType = Union[NDPromptText] \ No newline at end of file +NDPromptClassificationType = Union[NDPromptText] diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py index a38247271..01ab8454a 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py @@ -6,9 +6,11 @@ from labelbox.data.annotation_types.annotation import ObjectAnnotation from labelbox.data.annotation_types.classification.classification import ( - ClassificationAnnotation,) + ClassificationAnnotation, +) from labelbox.data.annotation_types.metrics.confusion_matrix import ( - ConfusionMatrixMetric,) + ConfusionMatrixMetric, +) from labelbox.data.annotation_types.metrics.scalar import ScalarMetric from 
labelbox.data.annotation_types.video import VideoMaskAnnotation @@ -24,7 +26,6 @@ class NDJsonConverter: - @staticmethod def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator: """ @@ -41,7 +42,8 @@ def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator: @staticmethod def serialize( - labels: LabelCollection) -> Generator[Dict[str, Any], None, None]: + labels: LabelCollection, + ) -> Generator[Dict[str, Any], None, None]: """ Converts a labelbox common object to the labelbox ndjson format (prediction import format) @@ -56,8 +58,9 @@ def serialize( """ used_uuids: Set[uuid.UUID] = set() - relationship_uuids: Dict[uuid.UUID, - Deque[uuid.UUID]] = defaultdict(deque) + relationship_uuids: Dict[uuid.UUID, Deque[uuid.UUID]] = defaultdict( + deque + ) # UUIDs are private properties used to enhance UX when defining relationships. # They are created for all annotations, but only utilized for relationships. @@ -66,15 +69,17 @@ def serialize( # For relationship annotations, during first pass, we update the UUIDs of the source and target annotations. # During the second pass, we update the UUIDs of the annotations referenced by the relationship annotations. for label in labels: - uuid_safe_annotations: List[Union[ - ClassificationAnnotation, - ObjectAnnotation, - VideoMaskAnnotation, - ScalarMetric, - ConfusionMatrixMetric, - RelationshipAnnotation, - MessageEvaluationTaskAnnotation, - ]] = [] + uuid_safe_annotations: List[ + Union[ + ClassificationAnnotation, + ObjectAnnotation, + VideoMaskAnnotation, + ScalarMetric, + ConfusionMatrixMetric, + RelationshipAnnotation, + MessageEvaluationTaskAnnotation, + ] + ] = [] # First pass to get all RelationshipAnnotaitons # and update the UUIDs of the source and target annotations for annotation in label.annotations: @@ -83,9 +88,11 @@ def serialize( new_source_uuid = uuid.uuid4() new_target_uuid = uuid.uuid4() relationship_uuids[annotation.value.source._uuid].append( - new_source_uuid) + new_source_uuid + ) relationship_uuids[annotation.value.target._uuid].append( - new_target_uuid) + new_target_uuid + ) annotation.value.source._uuid = new_source_uuid annotation.value.target._uuid = new_target_uuid if annotation._uuid in used_uuids: @@ -94,8 +101,9 @@ def serialize( uuid_safe_annotations.append(annotation) # Second pass to update UUIDs for annotations referenced by RelationshipAnnotations for annotation in label.annotations: - if (not isinstance(annotation, RelationshipAnnotation) and - hasattr(annotation, "_uuid")): + if not isinstance( + annotation, RelationshipAnnotation + ) and hasattr(annotation, "_uuid"): annotation = copy.deepcopy(annotation) next_uuids = relationship_uuids[annotation._uuid] if len(next_uuids) > 0: @@ -119,6 +127,6 @@ def serialize( for k, v in list(res.items()): if k in IGNORE_IF_NONE and v is None: del res[k] - if getattr(label, 'is_benchmark_reference'): - res['isBenchmarkReferenceLabel'] = True + if getattr(label, "is_benchmark_reference"): + res["isBenchmarkReferenceLabel"] = True yield res diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index b9e9f2456..18134a228 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -3,9 +3,15 @@ from typing import Dict, Generator, List, Tuple, Union from collections import defaultdict -from ...annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation +from 
...annotation_types.annotation import ( + ClassificationAnnotation, + ObjectAnnotation, +) from ...annotation_types.relationship import RelationshipAnnotation -from ...annotation_types.video import DICOMObjectAnnotation, VideoClassificationAnnotation +from ...annotation_types.video import ( + DICOMObjectAnnotation, + VideoClassificationAnnotation, +) from ...annotation_types.video import VideoObjectAnnotation, VideoMaskAnnotation from ...annotation_types.collection import LabelCollection, LabelGenerator from ...annotation_types.data import DicomData, ImageData, TextData, VideoData @@ -13,12 +19,29 @@ from ...annotation_types.label import Label from ...annotation_types.ner import TextEntity, ConversationEntity from ...annotation_types.metrics import ScalarMetric, ConfusionMatrixMetric -from ...annotation_types.llm_prompt_response.prompt import PromptClassificationAnnotation +from ...annotation_types.llm_prompt_response.prompt import ( + PromptClassificationAnnotation, +) from ...annotation_types.mmc import MessageEvaluationTaskAnnotation from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric -from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass, NDPromptClassification, NDPromptClassificationType, NDPromptText -from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks +from .classification import ( + NDChecklistSubclass, + NDClassification, + NDClassificationType, + NDRadioSubclass, + NDPromptClassification, + NDPromptClassificationType, + NDPromptText, +) +from .objects import ( + NDObject, + NDObjectType, + NDSegments, + NDDicomSegments, + NDVideoMasks, + NDDicomMasks, +) from .mmc import NDMessageTask from .relationship import NDRelationship from .base import DataRow @@ -27,19 +50,29 @@ from pydantic_core import PydanticUndefined from contextlib import suppress -AnnotationType = Union[NDObjectType, NDClassificationType, NDPromptClassificationType, - NDConfusionMatrixMetric, NDScalarMetric, NDDicomSegments, - NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship, - NDPromptText, NDMessageTask] +AnnotationType = Union[ + NDObjectType, + NDClassificationType, + NDPromptClassificationType, + NDConfusionMatrixMetric, + NDScalarMetric, + NDDicomSegments, + NDSegments, + NDDicomMasks, + NDVideoMasks, + NDRelationship, + NDPromptText, + NDMessageTask, +] class NDLabel(BaseModel): annotations: List[_SubclassRegistryBase] - + def __init__(self, **kwargs): # NOTE: Deserialization of subclasses in pydantic is difficult, see here https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83 # Below implements the subclass registry as mentioned in the article. The python dicts we pass in can be missing certain fields - # we essentially have to infer the type against all sub classes that have the _SubclasssRegistryBase inheritance. + # we essentially have to infer the type against all sub classes that have the _SubclasssRegistryBase inheritance. # It works by checking if the keys of our annotations we are missing in matches any required subclass. # More keys are prioritized over less keys (closer match). This is used when importing json to our base models not a lot of customer workflows # depend on this method but this works for all our existing tests with the bonus of added validation. 
(no subclass found it throws an error) @@ -49,46 +82,64 @@ def __init__(self, **kwargs): item_annotation_keys = annotation.keys() key_subclass_combos = defaultdict(list) for subclass in subclass_registry.values(): - # Get all required keys from subclass annotation_keys = [] for k, field in subclass.model_fields.items(): if field.default == PydanticUndefined and k != "uuid": - if hasattr(field, "alias") and field.alias in item_annotation_keys: + if ( + hasattr(field, "alias") + and field.alias in item_annotation_keys + ): annotation_keys.append(field.alias) - elif hasattr(field, "validation_alias") and field.validation_alias in item_annotation_keys: + elif ( + hasattr(field, "validation_alias") + and field.validation_alias + in item_annotation_keys + ): annotation_keys.append(field.validation_alias) else: annotation_keys.append(k) - + key_subclass_combos[subclass].extend(annotation_keys) - + # Sort by subclass that has the most keys i.e. the one with the most keys that matches is most likely our subclass - key_subclass_combos = dict(sorted(key_subclass_combos.items(), key = lambda x : len(x[1]), reverse=True)) + key_subclass_combos = dict( + sorted( + key_subclass_combos.items(), + key=lambda x: len(x[1]), + reverse=True, + ) + ) for subclass, key_subclass_combo in key_subclass_combos.items(): # Choose the keys from our dict we supplied that matches the required keys of a subclass - check_required_keys = all(key in list(item_annotation_keys) for key in key_subclass_combo) + check_required_keys = all( + key in list(item_annotation_keys) + for key in key_subclass_combo + ) if check_required_keys: # Keep trying subclasses until we find one that has valid values (does not throw an validation error) with suppress(ValidationError): annotation = subclass(**annotation) break if isinstance(annotation, dict): - raise ValueError(f"Could not find subclass for fields: {item_annotation_keys}") - + raise ValueError( + f"Could not find subclass for fields: {item_annotation_keys}" + ) + kwargs["annotations"][index] = annotation super().__init__(**kwargs) - class _Relationship(BaseModel): """This object holds information about the relationship""" + ndjson: NDRelationship source: str target: str class _AnnotationGroup(BaseModel): """Stores all the annotations and relationships per datarow""" + data_row: DataRow = None ndjson_annotations: Dict[str, AnnotationType] = {} relationships: List["NDLabel._Relationship"] = [] @@ -97,7 +148,10 @@ def to_common(self) -> LabelGenerator: annotation_groups = defaultdict(NDLabel._AnnotationGroup) for ndjson_annotation in self.annotations: - key = ndjson_annotation.data_row.id or ndjson_annotation.data_row.global_key + key = ( + ndjson_annotation.data_row.id + or ndjson_annotation.data_row.global_key + ) group = annotation_groups[key] if isinstance(ndjson_annotation, NDRelationship): @@ -105,7 +159,9 @@ def to_common(self) -> LabelGenerator: NDLabel._Relationship( ndjson=ndjson_annotation, source=ndjson_annotation.relationship.source, - target=ndjson_annotation.relationship.target)) + target=ndjson_annotation.relationship.target, + ) + ) else: # if this is the first object in this group, we # take note of the DataRow this group belongs to @@ -117,17 +173,22 @@ def to_common(self) -> LabelGenerator: # we need to change the value type of # `_AnnotationGroupTuple.ndjson_objects` to accept a list of objects # and adapt the code to support duplicate UUIDs - assert ndjson_annotation.uuid not in group.ndjson_annotations, f"UUID '{ndjson_annotation.uuid}' is not unique" + assert ( + 
ndjson_annotation.uuid not in group.ndjson_annotations + ), f"UUID '{ndjson_annotation.uuid}' is not unique" - group.ndjson_annotations[ - ndjson_annotation.uuid] = ndjson_annotation + group.ndjson_annotations[ndjson_annotation.uuid] = ( + ndjson_annotation + ) return LabelGenerator( - data=self._generate_annotations(annotation_groups)) + data=self._generate_annotations(annotation_groups) + ) @classmethod - def from_common(cls, - data: LabelCollection) -> Generator["NDLabel", None, None]: + def from_common( + cls, data: LabelCollection + ) -> Generator["NDLabel", None, None]: for label in data: yield from cls._create_non_video_annotations(label) yield from cls._create_video_annotations(label) @@ -144,68 +205,96 @@ def _generate_annotations( for uuid, ndjson_annotation in group.ndjson_annotations.items(): if isinstance(ndjson_annotation, NDDicomSegments): annotations.extend( - NDDicomSegments.to_common(ndjson_annotation, - ndjson_annotation.name, - ndjson_annotation.schema_id)) + NDDicomSegments.to_common( + ndjson_annotation, + ndjson_annotation.name, + ndjson_annotation.schema_id, + ) + ) elif isinstance(ndjson_annotation, NDSegments): annotations.extend( - NDSegments.to_common(ndjson_annotation, - ndjson_annotation.name, - ndjson_annotation.schema_id)) + NDSegments.to_common( + ndjson_annotation, + ndjson_annotation.name, + ndjson_annotation.schema_id, + ) + ) elif isinstance(ndjson_annotation, NDDicomMasks): annotations.append( - NDDicomMasks.to_common(ndjson_annotation)) + NDDicomMasks.to_common(ndjson_annotation) + ) elif isinstance(ndjson_annotation, NDVideoMasks): annotations.append( - NDVideoMasks.to_common(ndjson_annotation)) + NDVideoMasks.to_common(ndjson_annotation) + ) elif isinstance(ndjson_annotation, NDObjectType.__args__): annotation = NDObject.to_common(ndjson_annotation) annotations.append(annotation) relationship_annotations[uuid] = annotation - elif isinstance(ndjson_annotation, - NDClassificationType.__args__): + elif isinstance( + ndjson_annotation, NDClassificationType.__args__ + ): annotations.extend( - NDClassification.to_common(ndjson_annotation)) - elif isinstance(ndjson_annotation, - (NDScalarMetric, NDConfusionMatrixMetric)): + NDClassification.to_common(ndjson_annotation) + ) + elif isinstance( + ndjson_annotation, (NDScalarMetric, NDConfusionMatrixMetric) + ): annotations.append( - NDMetricAnnotation.to_common(ndjson_annotation)) + NDMetricAnnotation.to_common(ndjson_annotation) + ) elif isinstance(ndjson_annotation, NDPromptClassificationType): - annotation = NDPromptClassification.to_common(ndjson_annotation) + annotation = NDPromptClassification.to_common( + ndjson_annotation + ) annotations.append(annotation) elif isinstance(ndjson_annotation, NDMessageTask): annotations.append(ndjson_annotation.to_common()) else: raise TypeError( - f"Unsupported annotation. {type(ndjson_annotation)}") + f"Unsupported annotation. 
{type(ndjson_annotation)}" + ) # after all the annotations have been discovered, we can now create # the relationship objects and use references to the objects # involved for relationship in group.relationships: try: - source, target = relationship_annotations[ - relationship.source], relationship_annotations[ - relationship.target] + source, target = ( + relationship_annotations[relationship.source], + relationship_annotations[relationship.target], + ) except KeyError: raise ValueError( f"Relationship object refers to nonexistent object with UUID '{relationship.source}' and/or '{relationship.target}'" ) annotations.append( - NDRelationship.to_common(relationship.ndjson, source, - target)) + NDRelationship.to_common( + relationship.ndjson, source, target + ) + ) - yield Label(annotations=annotations, - data=self._infer_media_type(group.data_row, - annotations)) + yield Label( + annotations=annotations, + data=self._infer_media_type(group.data_row, annotations), + ) def _infer_media_type( - self, data_row: DataRow, - annotations: List[Union[TextEntity, ConversationEntity, - VideoClassificationAnnotation, - DICOMObjectAnnotation, VideoObjectAnnotation, - ObjectAnnotation, ClassificationAnnotation, - ScalarMetric, ConfusionMatrixMetric]] + self, + data_row: DataRow, + annotations: List[ + Union[ + TextEntity, + ConversationEntity, + VideoClassificationAnnotation, + DICOMObjectAnnotation, + VideoObjectAnnotation, + ObjectAnnotation, + ClassificationAnnotation, + ScalarMetric, + ConfusionMatrixMetric, + ] + ], ) -> Union[TextData, VideoData, ImageData]: if len(annotations) == 0: raise ValueError("Missing annotations while inferring media type") @@ -214,7 +303,10 @@ def _infer_media_type( data = GenericDataRowData if (TextEntity in types) or (ConversationEntity in types): data = TextData - elif VideoClassificationAnnotation in types or VideoObjectAnnotation in types: + elif ( + VideoClassificationAnnotation in types + or VideoObjectAnnotation in types + ): data = VideoData elif DICOMObjectAnnotation in types: data = DicomData @@ -226,7 +318,8 @@ def _infer_media_type( @staticmethod def _get_consecutive_frames( - frames_indices: List[int]) -> List[Tuple[int, int]]: + frames_indices: List[int], + ) -> List[Tuple[int, int]]: consecutive = [] for k, g in groupby(enumerate(frames_indices), lambda x: x[0] - x[1]): group = list(map(itemgetter(1), g)) @@ -235,18 +328,23 @@ def _get_consecutive_frames( @classmethod def _get_segment_frame_ranges( - cls, annotation_group: List[Union[VideoClassificationAnnotation, - VideoObjectAnnotation]] + cls, + annotation_group: List[ + Union[VideoClassificationAnnotation, VideoObjectAnnotation] + ], ) -> List[Tuple[int, int]]: - sorted_frame_segment_indices = sorted([ - (annotation.frame, annotation.segment_index) - for annotation in annotation_group - if annotation.segment_index is not None - ]) + sorted_frame_segment_indices = sorted( + [ + (annotation.frame, annotation.segment_index) + for annotation in annotation_group + if annotation.segment_index is not None + ] + ) if len(sorted_frame_segment_indices) == 0: # Group segment by consecutive frames, since `segment_index` is not present return cls._get_consecutive_frames( - sorted([annotation.frame for annotation in annotation_group])) + sorted([annotation.frame for annotation in annotation_group]) + ) elif len(sorted_frame_segment_indices) == len(annotation_group): # Group segment by segment_index last_segment_id = 0 @@ -264,32 +362,34 @@ def _get_segment_frame_ranges( return frame_ranges else: raise ValueError( - 
f"Video annotations cannot partially have `segment_index` set") + f"Video annotations cannot partially have `segment_index` set" + ) @classmethod def _create_video_annotations( cls, label: Label ) -> Generator[Union[NDChecklistSubclass, NDRadioSubclass], None, None]: - video_annotations = defaultdict(list) for annot in label.annotations: if isinstance( - annot, - (VideoClassificationAnnotation, VideoObjectAnnotation)): - video_annotations[annot.feature_schema_id or - annot.name].append(annot) + annot, (VideoClassificationAnnotation, VideoObjectAnnotation) + ): + video_annotations[annot.feature_schema_id or annot.name].append( + annot + ) elif isinstance(annot, VideoMaskAnnotation): yield NDObject.from_common(annotation=annot, data=label.data) for annotation_group in video_annotations.values(): segment_frame_ranges = cls._get_segment_frame_ranges( - annotation_group) + annotation_group + ) if isinstance(annotation_group[0], VideoClassificationAnnotation): annotation = annotation_group[0] frames_data = [] for frames in segment_frame_ranges: - frames_data.append({'start': frames[0], 'end': frames[-1]}) - annotation.extra.update({'frames': frames_data}) + frames_data.append({"start": frames[0], "end": frames[-1]}) + annotation.extra.update({"frames": frames_data}) yield NDClassification.from_common(annotation, label.data) elif isinstance(annotation_group[0], VideoObjectAnnotation): @@ -297,7 +397,10 @@ def _create_video_annotations( for start_frame, end_frame in segment_frame_ranges: segment = [] for annotation in annotation_group: - if annotation.keyframe and start_frame <= annotation.frame <= end_frame: + if ( + annotation.keyframe + and start_frame <= annotation.frame <= end_frame + ): segment.append(annotation) segments.append(segment) yield NDObject.from_common(segments, label.data) @@ -305,10 +408,16 @@ def _create_video_annotations( @classmethod def _create_non_video_annotations(cls, label: Label): non_video_annotations = [ - annot for annot in label.annotations - if not isinstance(annot, (VideoClassificationAnnotation, - VideoObjectAnnotation, - VideoMaskAnnotation)) + annot + for annot in label.annotations + if not isinstance( + annot, + ( + VideoClassificationAnnotation, + VideoObjectAnnotation, + VideoMaskAnnotation, + ), + ) ] for annotation in non_video_annotations: if isinstance(annotation, ClassificationAnnotation): diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py index 9fd90544c..60d538b19 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py @@ -3,11 +3,17 @@ from labelbox.data.annotation_types.data import ImageData, TextData from labelbox.data.serialization.ndjson.base import DataRow, NDJsonBase from labelbox.data.annotation_types.metrics.scalar import ( - ScalarMetric, ScalarMetricAggregation, ScalarMetricValue, - ScalarMetricConfidenceValue) + ScalarMetric, + ScalarMetricAggregation, + ScalarMetricValue, + ScalarMetricConfidenceValue, +) from labelbox.data.annotation_types.metrics.confusion_matrix import ( - ConfusionMatrixAggregation, ConfusionMatrixMetric, - ConfusionMatrixMetricValue, ConfusionMatrixMetricConfidenceValue) + ConfusionMatrixAggregation, + ConfusionMatrixMetric, + ConfusionMatrixMetricValue, + ConfusionMatrixMetricConfidenceValue, +) from pydantic import ConfigDict, model_serializer from .base import _SubclassRegistryBase @@ -16,71 +22,82 @@ class BaseNDMetric(NDJsonBase): 
metric_value: float feature_name: Optional[str] = None subclass_name: Optional[str] = None - model_config = ConfigDict(use_enum_values = True) + model_config = ConfigDict(use_enum_values=True) - @model_serializer(mode = "wrap") + @model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) - for field in ['featureName', 'subclassName']: + for field in ["featureName", "subclassName"]: if field in res and res[field] is None: res.pop(field) return res class NDConfusionMatrixMetric(BaseNDMetric, _SubclassRegistryBase): - metric_value: Union[ConfusionMatrixMetricValue, - ConfusionMatrixMetricConfidenceValue] + metric_value: Union[ + ConfusionMatrixMetricValue, ConfusionMatrixMetricConfidenceValue + ] metric_name: str aggregation: ConfusionMatrixAggregation def to_common(self) -> ConfusionMatrixMetric: - return ConfusionMatrixMetric(value=self.metric_value, - metric_name=self.metric_name, - feature_name=self.feature_name, - subclass_name=self.subclass_name, - aggregation=self.aggregation, - extra={'uuid': self.uuid}) + return ConfusionMatrixMetric( + value=self.metric_value, + metric_name=self.metric_name, + feature_name=self.feature_name, + subclass_name=self.subclass_name, + aggregation=self.aggregation, + extra={"uuid": self.uuid}, + ) @classmethod def from_common( - cls, metric: ConfusionMatrixMetric, - data: Union[TextData, ImageData]) -> "NDConfusionMatrixMetric": - return cls(uuid=metric.extra.get('uuid'), - metric_value=metric.value, - metric_name=metric.metric_name, - feature_name=metric.feature_name, - subclass_name=metric.subclass_name, - aggregation=metric.aggregation, - data_row=DataRow(id=data.uid, global_key=data.global_key)) + cls, metric: ConfusionMatrixMetric, data: Union[TextData, ImageData] + ) -> "NDConfusionMatrixMetric": + return cls( + uuid=metric.extra.get("uuid"), + metric_value=metric.value, + metric_name=metric.metric_name, + feature_name=metric.feature_name, + subclass_name=metric.subclass_name, + aggregation=metric.aggregation, + data_row=DataRow(id=data.uid, global_key=data.global_key), + ) class NDScalarMetric(BaseNDMetric, _SubclassRegistryBase): metric_value: Union[ScalarMetricValue, ScalarMetricConfidenceValue] metric_name: Optional[str] = None - aggregation: Optional[ScalarMetricAggregation] = ScalarMetricAggregation.ARITHMETIC_MEAN + aggregation: Optional[ScalarMetricAggregation] = ( + ScalarMetricAggregation.ARITHMETIC_MEAN + ) def to_common(self) -> ScalarMetric: - return ScalarMetric(value=self.metric_value, - metric_name=self.metric_name, - feature_name=self.feature_name, - subclass_name=self.subclass_name, - aggregation=self.aggregation, - extra={'uuid': self.uuid}) + return ScalarMetric( + value=self.metric_value, + metric_name=self.metric_name, + feature_name=self.feature_name, + subclass_name=self.subclass_name, + aggregation=self.aggregation, + extra={"uuid": self.uuid}, + ) @classmethod - def from_common(cls, metric: ScalarMetric, - data: Union[TextData, ImageData]) -> "NDScalarMetric": - return cls(uuid=metric.extra.get('uuid'), - metric_value=metric.value, - metric_name=metric.metric_name, - feature_name=metric.feature_name, - subclass_name=metric.subclass_name, - aggregation=metric.aggregation.value, - data_row=DataRow(id=data.uid, global_key=data.global_key)) + def from_common( + cls, metric: ScalarMetric, data: Union[TextData, ImageData] + ) -> "NDScalarMetric": + return cls( + uuid=metric.extra.get("uuid"), + metric_value=metric.value, + metric_name=metric.metric_name, + feature_name=metric.feature_name, + 
subclass_name=metric.subclass_name, + aggregation=metric.aggregation.value, + data_row=DataRow(id=data.uid, global_key=data.global_key), + ) class NDMetricAnnotation: - @classmethod def to_common( cls, annotation: Union[NDScalarMetric, NDConfusionMatrixMetric] @@ -89,16 +106,16 @@ def to_common( @classmethod def from_common( - cls, annotation: Union[ScalarMetric, - ConfusionMatrixMetric], data: Union[TextData, - ImageData] + cls, + annotation: Union[ScalarMetric, ConfusionMatrixMetric], + data: Union[TextData, ImageData], ) -> Union[NDScalarMetric, NDConfusionMatrixMetric]: obj = cls.lookup_object(annotation) return obj.from_common(annotation, data) @staticmethod def lookup_object( - annotation: Union[ScalarMetric, ConfusionMatrixMetric] + annotation: Union[ScalarMetric, ConfusionMatrixMetric], ) -> Union[Type[NDScalarMetric], Type[NDConfusionMatrixMetric]]: result = { ScalarMetric: NDScalarMetric, @@ -106,5 +123,6 @@ def lookup_object( }.get(type(annotation)) if result is None: raise TypeError( - f"Unable to convert object to MAL format. `{type(annotation)}`") + f"Unable to convert object to MAL format. `{type(annotation)}`" + ) return result diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py index 7b1908b76..4cb797f38 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py @@ -4,17 +4,24 @@ from .base import _SubclassRegistryBase, DataRow, NDAnnotation from ...annotation_types.types import Cuid -from ...annotation_types.mmc import MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask, MessageEvaluationTaskAnnotation +from ...annotation_types.mmc import ( + MessageSingleSelectionTask, + MessageMultiSelectionTask, + MessageRankingTask, + MessageEvaluationTaskAnnotation, +) class MessageTaskData(_CamelCaseMixin): format: str - data: Union[MessageSingleSelectionTask, MessageMultiSelectionTask, - MessageRankingTask] + data: Union[ + MessageSingleSelectionTask, + MessageMultiSelectionTask, + MessageRankingTask, + ] class NDMessageTask(NDAnnotation, _SubclassRegistryBase): - message_evaluation_task: MessageTaskData def to_common(self) -> MessageEvaluationTaskAnnotation: @@ -27,13 +34,16 @@ def to_common(self) -> MessageEvaluationTaskAnnotation: @classmethod def from_common( - cls, - annotation: MessageEvaluationTaskAnnotation, - data: Any #Union[ImageData, TextData], + cls, + annotation: MessageEvaluationTaskAnnotation, + data: Any, # Union[ImageData, TextData], ) -> "NDMessageTask": - return cls(uuid=str(annotation._uuid), - name=annotation.name, - schema_id=annotation.feature_schema_id, - data_row=DataRow(id=data.uid, global_key=data.global_key), - message_evaluation_task=MessageTaskData( - format=annotation.value.format, data=annotation.value)) + return cls( + uuid=str(annotation._uuid), + name=annotation.name, + schema_id=annotation.feature_schema_id, + data_row=DataRow(id=data.uid, global_key=data.global_key), + message_evaluation_task=MessageTaskData( + format=annotation.value.format, data=annotation.value + ), + ) diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py index 2b32f1c2b..79e9b4adf 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py @@ -2,9 +2,19 @@ from typing import Any, Dict, List, Tuple, Union, Optional 
import base64 -from labelbox.data.annotation_types.ner.conversation_entity import ConversationEntity -from labelbox.data.annotation_types.video import VideoObjectAnnotation, DICOMObjectAnnotation -from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin, CustomMetric, CustomMetricsNotSupportedMixin +from labelbox.data.annotation_types.ner.conversation_entity import ( + ConversationEntity, +) +from labelbox.data.annotation_types.video import ( + VideoObjectAnnotation, + DICOMObjectAnnotation, +) +from labelbox.data.mixins import ( + ConfidenceMixin, + CustomMetricsMixin, + CustomMetric, + CustomMetricsNotSupportedMixin, +) import numpy as np from PIL import Image @@ -13,12 +23,35 @@ from labelbox.data.annotation_types.data.video import VideoData from ...annotation_types.data import ImageData, TextData, MaskData -from ...annotation_types.ner import DocumentEntity, DocumentTextSelection, TextEntity +from ...annotation_types.ner import ( + DocumentEntity, + DocumentTextSelection, + TextEntity, +) from ...annotation_types.types import Cuid -from ...annotation_types.geometry import DocumentRectangle, Rectangle, Polygon, Line, Point, Mask -from ...annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation -from ...annotation_types.video import VideoMaskAnnotation, DICOMMaskAnnotation, MaskFrame, MaskInstance -from .classification import NDClassification, NDSubclassification, NDSubclassificationType +from ...annotation_types.geometry import ( + DocumentRectangle, + Rectangle, + Polygon, + Line, + Point, + Mask, +) +from ...annotation_types.annotation import ( + ClassificationAnnotation, + ObjectAnnotation, +) +from ...annotation_types.video import ( + VideoMaskAnnotation, + DICOMMaskAnnotation, + MaskFrame, + MaskInstance, +) +from .classification import ( + NDClassification, + NDSubclassification, + NDSubclassificationType, +) from .base import DataRow, NDAnnotation, NDJsonBase, _SubclassRegistryBase from pydantic import BaseModel @@ -48,7 +81,9 @@ class Bbox(BaseModel): width: float -class NDPoint(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase): +class NDPoint( + NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase +): point: _Point def to_common(self) -> Point: @@ -56,46 +91,48 @@ def to_common(self) -> Point: @classmethod def from_common( - cls, - uuid: str, - point: Point, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None) -> "NDPoint": - return cls(point={ - 'x': point.x, - 'y': point.y - }, - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) + cls, + uuid: str, + point: Point, + classifications: List[ClassificationAnnotation], + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[ImageData, TextData], + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, + ) -> "NDPoint": + return cls( + point={"x": point.x, "y": point.y}, + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + confidence=confidence, + custom_metrics=custom_metrics, + ) class NDFramePoint(VideoSupported, _SubclassRegistryBase): point: _Point 
classifications: List[NDSubclassificationType] = [] - def to_common(self, name: str, feature_schema_id: Cuid, - segment_index: int) -> VideoObjectAnnotation: - return VideoObjectAnnotation(frame=self.frame, - segment_index=segment_index, - keyframe=True, - name=name, - feature_schema_id=feature_schema_id, - value=Point(x=self.point.x, - y=self.point.y), - classifications=[ - NDSubclassification.to_common(annot) - for annot in self.classifications - ]) + def to_common( + self, name: str, feature_schema_id: Cuid, segment_index: int + ) -> VideoObjectAnnotation: + return VideoObjectAnnotation( + frame=self.frame, + segment_index=segment_index, + keyframe=True, + name=name, + feature_schema_id=feature_schema_id, + value=Point(x=self.point.x, y=self.point.y), + classifications=[ + NDSubclassification.to_common(annot) + for annot in self.classifications + ], + ) @classmethod def from_common( @@ -104,12 +141,16 @@ def from_common( point: Point, classifications: List[NDSubclassificationType], ): - return cls(frame=frame, - point=_Point(x=point.x, y=point.y), - classifications=classifications) + return cls( + frame=frame, + point=_Point(x=point.x, y=point.y), + classifications=classifications, + ) -class NDLine(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase): +class NDLine( + NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase +): line: List[_Point] def to_common(self) -> Line: @@ -117,35 +158,36 @@ def to_common(self) -> Line: @classmethod def from_common( - cls, - uuid: str, - line: Line, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None) -> "NDLine": - return cls(line=[{ - 'x': pt.x, - 'y': pt.y - } for pt in line.points], - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) + cls, + uuid: str, + line: Line, + classifications: List[ClassificationAnnotation], + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[ImageData, TextData], + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, + ) -> "NDLine": + return cls( + line=[{"x": pt.x, "y": pt.y} for pt in line.points], + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + confidence=confidence, + custom_metrics=custom_metrics, + ) class NDFrameLine(VideoSupported, _SubclassRegistryBase): line: List[_Point] classifications: List[NDSubclassificationType] = [] - def to_common(self, name: str, feature_schema_id: Cuid, - segment_index: int) -> VideoObjectAnnotation: + def to_common( + self, name: str, feature_schema_id: Cuid, segment_index: int + ) -> VideoObjectAnnotation: return VideoObjectAnnotation( frame=self.frame, segment_index=segment_index, @@ -156,7 +198,8 @@ def to_common(self, name: str, feature_schema_id: Cuid, classifications=[ NDSubclassification.to_common(annot) for annot in self.classifications - ]) + ], + ) @classmethod def from_common( @@ -165,18 +208,21 @@ def from_common( line: Line, classifications: List[NDSubclassificationType], ): - return cls(frame=frame, - line=[{ - 'x': pt.x, - 'y': pt.y - } for pt in line.points], - classifications=classifications) + return cls( + 
frame=frame, + line=[{"x": pt.x, "y": pt.y} for pt in line.points], + classifications=classifications, + ) class NDDicomLine(NDFrameLine, _SubclassRegistryBase): - - def to_common(self, name: str, feature_schema_id: Cuid, segment_index: int, - group_key: str) -> DICOMObjectAnnotation: + def to_common( + self, + name: str, + feature_schema_id: Cuid, + segment_index: int, + group_key: str, + ) -> DICOMObjectAnnotation: return DICOMObjectAnnotation( frame=self.frame, segment_index=segment_index, @@ -184,10 +230,13 @@ def to_common(self, name: str, feature_schema_id: Cuid, segment_index: int, name=name, feature_schema_id=feature_schema_id, value=Line(points=[Point(x=pt.x, y=pt.y) for pt in self.line]), - group_key=group_key) + group_key=group_key, + ) -class NDPolygon(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase): +class NDPolygon( + NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase +): polygon: List[_Point] def to_common(self) -> Polygon: @@ -195,63 +244,73 @@ def to_common(self) -> Polygon: @classmethod def from_common( - cls, - uuid: str, - polygon: Polygon, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None) -> "NDPolygon": - return cls(polygon=[{ - 'x': pt.x, - 'y': pt.y - } for pt in polygon.points], - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) - - -class NDRectangle(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase): + cls, + uuid: str, + polygon: Polygon, + classifications: List[ClassificationAnnotation], + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[ImageData, TextData], + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, + ) -> "NDPolygon": + return cls( + polygon=[{"x": pt.x, "y": pt.y} for pt in polygon.points], + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + confidence=confidence, + custom_metrics=custom_metrics, + ) + + +class NDRectangle( + NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase +): bbox: Bbox def to_common(self) -> Rectangle: - return Rectangle(start=Point(x=self.bbox.left, y=self.bbox.top), - end=Point(x=self.bbox.left + self.bbox.width, - y=self.bbox.top + self.bbox.height)) + return Rectangle( + start=Point(x=self.bbox.left, y=self.bbox.top), + end=Point( + x=self.bbox.left + self.bbox.width, + y=self.bbox.top + self.bbox.height, + ), + ) @classmethod def from_common( - cls, - uuid: str, - rectangle: Rectangle, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None + cls, + uuid: str, + rectangle: Rectangle, + classifications: List[ClassificationAnnotation], + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[ImageData, TextData], + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDRectangle": - return cls(bbox=Bbox(top=min(rectangle.start.y, rectangle.end.y), - 
left=min(rectangle.start.x, rectangle.end.x), - height=abs(rectangle.end.y - rectangle.start.y), - width=abs(rectangle.end.x - rectangle.start.x)), - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - page=extra.get('page'), - unit=extra.get('unit'), - confidence=confidence, - custom_metrics=custom_metrics) + return cls( + bbox=Bbox( + top=min(rectangle.start.y, rectangle.end.y), + left=min(rectangle.start.x, rectangle.end.x), + height=abs(rectangle.end.y - rectangle.start.y), + width=abs(rectangle.end.x - rectangle.start.x), + ), + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + page=extra.get("page"), + unit=extra.get("unit"), + confidence=confidence, + custom_metrics=custom_metrics, + ) class NDDocumentRectangle(NDRectangle, _SubclassRegistryBase): @@ -259,59 +318,73 @@ class NDDocumentRectangle(NDRectangle, _SubclassRegistryBase): unit: str def to_common(self) -> DocumentRectangle: - return DocumentRectangle(start=Point(x=self.bbox.left, y=self.bbox.top), - end=Point(x=self.bbox.left + self.bbox.width, - y=self.bbox.top + self.bbox.height), - page=self.page, - unit=self.unit) + return DocumentRectangle( + start=Point(x=self.bbox.left, y=self.bbox.top), + end=Point( + x=self.bbox.left + self.bbox.width, + y=self.bbox.top + self.bbox.height, + ), + page=self.page, + unit=self.unit, + ) @classmethod def from_common( - cls, - uuid: str, - rectangle: DocumentRectangle, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None + cls, + uuid: str, + rectangle: DocumentRectangle, + classifications: List[ClassificationAnnotation], + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[ImageData, TextData], + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDRectangle": - return cls(bbox=Bbox(top=min(rectangle.start.y, rectangle.end.y), - left=min(rectangle.start.x, rectangle.end.x), - height=abs(rectangle.end.y - rectangle.start.y), - width=abs(rectangle.end.x - rectangle.start.x)), - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - page=rectangle.page, - unit=rectangle.unit.value, - confidence=confidence, - custom_metrics=custom_metrics) + return cls( + bbox=Bbox( + top=min(rectangle.start.y, rectangle.end.y), + left=min(rectangle.start.x, rectangle.end.x), + height=abs(rectangle.end.y - rectangle.start.y), + width=abs(rectangle.end.x - rectangle.start.x), + ), + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + page=rectangle.page, + unit=rectangle.unit.value, + confidence=confidence, + custom_metrics=custom_metrics, + ) class NDFrameRectangle(VideoSupported, _SubclassRegistryBase): bbox: Bbox classifications: List[NDSubclassificationType] = [] - def to_common(self, name: str, feature_schema_id: Cuid, - segment_index: int) -> VideoObjectAnnotation: + def to_common( + self, name: str, feature_schema_id: Cuid, segment_index: int + ) -> VideoObjectAnnotation: return VideoObjectAnnotation( frame=self.frame, segment_index=segment_index, 
keyframe=True, name=name, feature_schema_id=feature_schema_id, - value=Rectangle(start=Point(x=self.bbox.left, y=self.bbox.top), - end=Point(x=self.bbox.left + self.bbox.width, - y=self.bbox.top + self.bbox.height)), + value=Rectangle( + start=Point(x=self.bbox.left, y=self.bbox.top), + end=Point( + x=self.bbox.left + self.bbox.width, + y=self.bbox.top + self.bbox.height, + ), + ), classifications=[ NDSubclassification.to_common(annot) for annot in self.classifications - ]) + ], + ) @classmethod def from_common( @@ -320,12 +393,16 @@ def from_common( rectangle: Rectangle, classifications: List[NDSubclassificationType], ): - return cls(frame=frame, - bbox=Bbox(top=min(rectangle.start.y, rectangle.end.y), - left=min(rectangle.start.x, rectangle.end.x), - height=abs(rectangle.end.y - rectangle.start.y), - width=abs(rectangle.end.x - rectangle.start.x)), - classifications=classifications) + return cls( + frame=frame, + bbox=Bbox( + top=min(rectangle.start.y, rectangle.end.y), + left=min(rectangle.start.x, rectangle.end.x), + height=abs(rectangle.end.y - rectangle.start.y), + width=abs(rectangle.end.x - rectangle.start.x), + ), + classifications=classifications, + ) class NDSegment(BaseModel): @@ -343,19 +420,25 @@ def lookup_segment_object_type(segment: List) -> "NDFrameObjectType": return result @staticmethod - def segment_with_uuid(keyframe: Union[NDFrameRectangle, NDFramePoint, - NDFrameLine], uuid: str): + def segment_with_uuid( + keyframe: Union[NDFrameRectangle, NDFramePoint, NDFrameLine], uuid: str + ): keyframe._uuid = uuid - keyframe.extra = {'uuid': uuid} + keyframe.extra = {"uuid": uuid} return keyframe - def to_common(self, name: str, feature_schema_id: Cuid, uuid: str, - segment_index: int): + def to_common( + self, name: str, feature_schema_id: Cuid, uuid: str, segment_index: int + ): return [ self.segment_with_uuid( - keyframe.to_common(name=name, - feature_schema_id=feature_schema_id, - segment_index=segment_index), uuid) + keyframe.to_common( + name=name, + feature_schema_id=feature_schema_id, + segment_index=segment_index, + ), + uuid, + ) for keyframe in self.keyframes ] @@ -363,14 +446,19 @@ def to_common(self, name: str, feature_schema_id: Cuid, uuid: str, def from_common(cls, segment): nd_frame_object_type = cls.lookup_segment_object_type(segment) - return cls(keyframes=[ - nd_frame_object_type.from_common( - object_annotation.frame, object_annotation.value, [ - NDSubclassification.from_common(annot) - for annot in object_annotation.classifications - ]) - for object_annotation in segment - ]) + return cls( + keyframes=[ + nd_frame_object_type.from_common( + object_annotation.frame, + object_annotation.value, + [ + NDSubclassification.from_common(annot) + for annot in object_annotation.classifications + ], + ) + for object_annotation in segment + ] + ) class NDDicomSegment(NDSegment): @@ -384,16 +472,26 @@ def lookup_segment_object_type(segment: List) -> "NDDicomObjectType": if segment_class == Line: return NDDicomLine else: - raise ValueError('DICOM segments only support Line objects') + raise ValueError("DICOM segments only support Line objects") - def to_common(self, name: str, feature_schema_id: Cuid, uuid: str, - segment_index: int, group_key: str): + def to_common( + self, + name: str, + feature_schema_id: Cuid, + uuid: str, + segment_index: int, + group_key: str, + ): return [ self.segment_with_uuid( - keyframe.to_common(name=name, - feature_schema_id=feature_schema_id, - segment_index=segment_index, - group_key=group_key), uuid) + keyframe.to_common( + name=name, + 
feature_schema_id=feature_schema_id, + segment_index=segment_index, + group_key=group_key, + ), + uuid, + ) for keyframe in self.keyframes ] @@ -405,24 +503,33 @@ def to_common(self, name: str, feature_schema_id: Cuid): result = [] for idx, segment in enumerate(self.segments): result.extend( - segment.to_common(name=name, - feature_schema_id=feature_schema_id, - segment_index=idx, - uuid=self.uuid)) + segment.to_common( + name=name, + feature_schema_id=feature_schema_id, + segment_index=idx, + uuid=self.uuid, + ) + ) return result @classmethod - def from_common(cls, segments: List[VideoObjectAnnotation], data: VideoData, - name: str, feature_schema_id: Cuid, - extra: Dict[str, Any]) -> "NDSegments": - + def from_common( + cls, + segments: List[VideoObjectAnnotation], + data: VideoData, + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + ) -> "NDSegments": segments = [NDSegment.from_common(segment) for segment in segments] - return cls(segments=segments, - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=extra.get('uuid')) + return cls( + segments=segments, + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=extra.get("uuid"), + ) class NDDicomSegments(NDBaseObject, DicomSupported, _SubclassRegistryBase): @@ -432,26 +539,36 @@ def to_common(self, name: str, feature_schema_id: Cuid): result = [] for idx, segment in enumerate(self.segments): result.extend( - segment.to_common(name=name, - feature_schema_id=feature_schema_id, - segment_index=idx, - uuid=self.uuid, - group_key=self.group_key)) + segment.to_common( + name=name, + feature_schema_id=feature_schema_id, + segment_index=idx, + uuid=self.uuid, + group_key=self.group_key, + ) + ) return result @classmethod - def from_common(cls, segments: List[DICOMObjectAnnotation], data: VideoData, - name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - group_key: str) -> "NDDicomSegments": - + def from_common( + cls, + segments: List[DICOMObjectAnnotation], + data: VideoData, + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + group_key: str, + ) -> "NDDicomSegments": segments = [NDDicomSegment.from_common(segment) for segment in segments] - return cls(segments=segments, - dataRow=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=extra.get('uuid'), - group_key=group_key) + return cls( + segments=segments, + dataRow=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=extra.get("uuid"), + group_key=group_key, + ) class _URIMask(BaseModel): @@ -463,53 +580,61 @@ class _PNGMask(BaseModel): png: str -class NDMask(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase): +class NDMask( + NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase +): mask: Union[_URIMask, _PNGMask] def to_common(self) -> Mask: if isinstance(self.mask, _URIMask): - return Mask(mask=MaskData(url=self.mask.instanceURI), - color=self.mask.colorRGB) + return Mask( + mask=MaskData(url=self.mask.instanceURI), + color=self.mask.colorRGB, + ) else: - encoded_image_bytes = self.mask.png.encode('utf-8') + encoded_image_bytes = self.mask.png.encode("utf-8") image_bytes = base64.b64decode(encoded_image_bytes) image = np.array(Image.open(BytesIO(image_bytes))) if np.max(image) > 1: raise ValueError( - f"Expected binary mask. Found max value of {np.max(image)}") + f"Expected binary mask. 
Found max value of {np.max(image)}" + ) # Color is 1,1,1 because it is a binary array and we are just stacking it into 3 channels return Mask(mask=MaskData.from_2D_arr(image), color=(1, 1, 1)) @classmethod def from_common( - cls, - uuid: str, - mask: Mask, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None) -> "NDMask": - + cls, + uuid: str, + mask: Mask, + classifications: List[ClassificationAnnotation], + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[ImageData, TextData], + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, + ) -> "NDMask": if mask.mask.url is not None: lbv1_mask = _URIMask(instanceURI=mask.mask.url, colorRGB=mask.color) else: binary = np.all(mask.mask.value == mask.color, axis=-1) im_bytes = BytesIO() - Image.fromarray(binary, 'L').save(im_bytes, format="PNG") + Image.fromarray(binary, "L").save(im_bytes, format="PNG") lbv1_mask = _PNGMask( - png=base64.b64encode(im_bytes.getvalue()).decode('utf-8')) + png=base64.b64encode(im_bytes.getvalue()).decode("utf-8") + ) - return cls(mask=lbv1_mask, - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) + return cls( + mask=lbv1_mask, + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + confidence=confidence, + custom_metrics=custom_metrics, + ) class NDVideoMasksFramesInstances(BaseModel): @@ -517,14 +642,20 @@ class NDVideoMasksFramesInstances(BaseModel): instances: List[MaskInstance] -class NDVideoMasks(NDJsonBase, ConfidenceMixin, CustomMetricsNotSupportedMixin, _SubclassRegistryBase): +class NDVideoMasks( + NDJsonBase, + ConfidenceMixin, + CustomMetricsNotSupportedMixin, + _SubclassRegistryBase, +): masks: NDVideoMasksFramesInstances def to_common(self) -> VideoMaskAnnotation: for mask_frame in self.masks.frames: if mask_frame.im_bytes: mask_frame.im_bytes = base64.b64decode( - mask_frame.im_bytes.encode('utf-8')) + mask_frame.im_bytes.encode("utf-8") + ) return VideoMaskAnnotation( frames=self.masks.frames, @@ -536,17 +667,18 @@ def from_common(cls, annotation, data): for mask_frame in annotation.frames: if mask_frame.im_bytes: mask_frame.im_bytes = base64.b64encode( - mask_frame.im_bytes).decode('utf-8') + mask_frame.im_bytes + ).decode("utf-8") return cls( data_row=DataRow(id=data.uid, global_key=data.global_key), - masks=NDVideoMasksFramesInstances(frames=annotation.frames, - instances=annotation.instances), + masks=NDVideoMasksFramesInstances( + frames=annotation.frames, instances=annotation.instances + ), ) class NDDicomMasks(NDVideoMasks, DicomSupported, _SubclassRegistryBase): - def to_common(self) -> DICOMMaskAnnotation: return DICOMMaskAnnotation( frames=self.masks.frames, @@ -558,8 +690,9 @@ def to_common(self) -> DICOMMaskAnnotation: def from_common(cls, annotation, data): return cls( data_row=DataRow(id=data.uid, global_key=data.global_key), - masks=NDVideoMasksFramesInstances(frames=annotation.frames, - instances=annotation.instances), + masks=NDVideoMasksFramesInstances( + frames=annotation.frames, instances=annotation.instances + ), group_key=annotation.group_key.value, ) @@ -569,7 +702,9 @@ 
class Location(BaseModel): end: int -class NDTextEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase): +class NDTextEntity( + NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase +): location: Location def to_common(self) -> TextEntity: @@ -577,37 +712,42 @@ def to_common(self) -> TextEntity: @classmethod def from_common( - cls, - uuid: str, - text_entity: TextEntity, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None + cls, + uuid: str, + text_entity: TextEntity, + classifications: List[ClassificationAnnotation], + name: str, + feature_schema_id: Cuid, + extra: Dict[str, Any], + data: Union[ImageData, TextData], + confidence: Optional[float] = None, + custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDTextEntity": - return cls(location=Location( - start=text_entity.start, - end=text_entity.end, - ), - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) - - -class NDDocumentEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase): + return cls( + location=Location( + start=text_entity.start, + end=text_entity.end, + ), + data_row=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + confidence=confidence, + custom_metrics=custom_metrics, + ) + + +class NDDocumentEntity( + NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase +): name: str text_selections: List[DocumentTextSelection] def to_common(self) -> DocumentEntity: - return DocumentEntity(name=self.name, - text_selections=self.text_selections) + return DocumentEntity( + name=self.name, text_selections=self.text_selections + ) @classmethod def from_common( @@ -620,26 +760,29 @@ def from_common( extra: Dict[str, Any], data: Union[ImageData, TextData], confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None + custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDDocumentEntity": - - return cls(text_selections=document_entity.text_selections, - dataRow=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) + return cls( + text_selections=document_entity.text_selections, + dataRow=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + confidence=confidence, + custom_metrics=custom_metrics, + ) class NDConversationEntity(NDTextEntity, _SubclassRegistryBase): message_id: str def to_common(self) -> ConversationEntity: - return ConversationEntity(start=self.location.start, - end=self.location.end, - message_id=self.message_id) + return ConversationEntity( + start=self.location.start, + end=self.location.end, + message_id=self.message_id, + ) @classmethod def from_common( @@ -652,22 +795,24 @@ def from_common( extra: Dict[str, Any], data: Union[ImageData, TextData], confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None + custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDConversationEntity": - 
return cls(location=Location(start=conversation_entity.start, - end=conversation_entity.end), - message_id=conversation_entity.message_id, - dataRow=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) + return cls( + location=Location( + start=conversation_entity.start, end=conversation_entity.end + ), + message_id=conversation_entity.message_id, + dataRow=DataRow(id=data.uid, global_key=data.global_key), + name=name, + schema_id=feature_schema_id, + uuid=uuid, + classifications=classifications, + confidence=confidence, + custom_metrics=custom_metrics, + ) class NDObject: - @staticmethod def to_common(annotation: "NDObjectType") -> ObjectAnnotation: common_annotation = annotation.to_common() @@ -675,49 +820,66 @@ def to_common(annotation: "NDObjectType") -> ObjectAnnotation: NDSubclassification.to_common(annot) for annot in annotation.classifications ] - confidence = annotation.confidence if hasattr(annotation, - 'confidence') else None - - custom_metrics = annotation.custom_metrics if hasattr( - annotation, 'custom_metrics') else None - return ObjectAnnotation(value=common_annotation, - name=annotation.name, - feature_schema_id=annotation.schema_id, - classifications=classifications, - extra={ - 'uuid': annotation.uuid, - 'page': annotation.page, - 'unit': annotation.unit - }, - confidence=confidence, - custom_metrics=custom_metrics) + confidence = ( + annotation.confidence if hasattr(annotation, "confidence") else None + ) + + custom_metrics = ( + annotation.custom_metrics + if hasattr(annotation, "custom_metrics") + else None + ) + return ObjectAnnotation( + value=common_annotation, + name=annotation.name, + feature_schema_id=annotation.schema_id, + classifications=classifications, + extra={ + "uuid": annotation.uuid, + "page": annotation.page, + "unit": annotation.unit, + }, + confidence=confidence, + custom_metrics=custom_metrics, + ) @classmethod def from_common( cls, - annotation: Union[ObjectAnnotation, List[List[VideoObjectAnnotation]], - VideoMaskAnnotation], data: Union[ImageData, TextData] - ) -> Union[NDLine, NDPoint, NDPolygon, NDDocumentRectangle, NDRectangle, - NDMask, NDTextEntity]: + annotation: Union[ + ObjectAnnotation, + List[List[VideoObjectAnnotation]], + VideoMaskAnnotation, + ], + data: Union[ImageData, TextData], + ) -> Union[ + NDLine, + NDPoint, + NDPolygon, + NDDocumentRectangle, + NDRectangle, + NDMask, + NDTextEntity, + ]: obj = cls.lookup_object(annotation) # if it is video segments - if (obj == NDSegments or obj == NDDicomSegments): - + if obj == NDSegments or obj == NDDicomSegments: first_video_annotation = annotation[0][0] args = dict( segments=annotation, data=data, name=first_video_annotation.name, feature_schema_id=first_video_annotation.feature_schema_id, - extra=first_video_annotation.extra) + extra=first_video_annotation.extra, + ) if isinstance(first_video_annotation, DICOMObjectAnnotation): group_key = first_video_annotation.group_key.value args.update(dict(group_key=group_key)) return obj.from_common(**args) - elif (obj == NDVideoMasks or obj == NDDicomMasks): + elif obj == NDVideoMasks or obj == NDDicomMasks: return obj.from_common(annotation, data) subclasses = [ @@ -725,21 +887,27 @@ def from_common( for annot in annotation.classifications ] optional_kwargs = {} - if (annotation.confidence): - optional_kwargs['confidence'] = annotation.confidence - - if (annotation.custom_metrics): - 
optional_kwargs['custom_metrics'] = annotation.custom_metrics - - return obj.from_common(str(annotation._uuid), annotation.value, - subclasses, annotation.name, - annotation.feature_schema_id, annotation.extra, - data, **optional_kwargs) + if annotation.confidence: + optional_kwargs["confidence"] = annotation.confidence + + if annotation.custom_metrics: + optional_kwargs["custom_metrics"] = annotation.custom_metrics + + return obj.from_common( + str(annotation._uuid), + annotation.value, + subclasses, + annotation.name, + annotation.feature_schema_id, + annotation.extra, + data, + **optional_kwargs, + ) @staticmethod def lookup_object( - annotation: Union[ObjectAnnotation, List]) -> "NDObjectType": - + annotation: Union[ObjectAnnotation, List], + ) -> "NDObjectType": if isinstance(annotation, DICOMMaskAnnotation): result = NDDicomMasks elif isinstance(annotation, VideoMaskAnnotation): @@ -772,9 +940,18 @@ def lookup_object( ) return result + NDEntityType = Union[NDConversationEntity, NDTextEntity] -NDObjectType = Union[NDLine, NDPolygon, NDPoint, NDDocumentRectangle, - NDRectangle, NDMask, NDEntityType, NDDocumentEntity] +NDObjectType = Union[ + NDLine, + NDPolygon, + NDPoint, + NDDocumentRectangle, + NDRectangle, + NDMask, + NDEntityType, + NDDocumentEntity, +] NDFrameObjectType = NDFrameRectangle, NDFramePoint, NDFrameLine NDDicomObjectType = NDDicomLine diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py index 1cdb23b76..fbea7e477 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py @@ -20,25 +20,36 @@ class NDRelationship(NDAnnotation, _SubclassRegistryBase): relationship: _Relationship @staticmethod - def to_common(annotation: "NDRelationship", source: SUPPORTED_ANNOTATIONS, - target: SUPPORTED_ANNOTATIONS) -> RelationshipAnnotation: - return RelationshipAnnotation(name=annotation.name, - value=Relationship( - source=source, - target=target, - type=Relationship.Type( - annotation.relationship.type)), - extra={'uuid': annotation.uuid}, - feature_schema_id=annotation.schema_id) + def to_common( + annotation: "NDRelationship", + source: SUPPORTED_ANNOTATIONS, + target: SUPPORTED_ANNOTATIONS, + ) -> RelationshipAnnotation: + return RelationshipAnnotation( + name=annotation.name, + value=Relationship( + source=source, + target=target, + type=Relationship.Type(annotation.relationship.type), + ), + extra={"uuid": annotation.uuid}, + feature_schema_id=annotation.schema_id, + ) @classmethod - def from_common(cls, annotation: RelationshipAnnotation, - data: Union[ImageData, TextData]) -> "NDRelationship": + def from_common( + cls, + annotation: RelationshipAnnotation, + data: Union[ImageData, TextData], + ) -> "NDRelationship": relationship = annotation.value - return cls(uuid=str(annotation._uuid), - name=annotation.name, - dataRow=DataRow(id=data.uid, global_key=data.global_key), - relationship=_Relationship( - source=str(relationship.source._uuid), - target=str(relationship.target._uuid), - type=relationship.type.value)) + return cls( + uuid=str(annotation._uuid), + name=annotation.name, + dataRow=DataRow(id=data.uid, global_key=data.global_key), + relationship=_Relationship( + source=str(relationship.source._uuid), + target=str(relationship.target._uuid), + type=relationship.type.value, + ), + ) diff --git a/libs/labelbox/src/labelbox/exceptions.py 
b/libs/labelbox/src/labelbox/exceptions.py index 048ca0757..34cfeaf4d 100644 --- a/libs/labelbox/src/labelbox/exceptions.py +++ b/libs/labelbox/src/labelbox/exceptions.py @@ -21,16 +21,18 @@ def __str__(self): class AuthenticationError(LabelboxError): """Raised when an API key fails authentication.""" + pass class AuthorizationError(LabelboxError): """Raised when a user is unauthorized to perform the given request.""" + pass class ResourceNotFoundError(LabelboxError): - """Exception raised when a given resource is not found. """ + """Exception raised when a given resource is not found.""" def __init__(self, db_object_type=None, params=None, message=None): """Constructor for the ResourceNotFoundException class. @@ -43,14 +45,17 @@ def __init__(self, db_object_type=None, params=None, message=None): if message is not None: super().__init__(message) else: - super().__init__("Resource '%s' not found for params: %r" % - (db_object_type.type_name(), params)) + super().__init__( + "Resource '%s' not found for params: %r" + % (db_object_type.type_name(), params) + ) self.db_object_type = db_object_type self.params = params class ResourceConflict(LabelboxError): - """Exception raised when a given resource conflicts with another. """ + """Exception raised when a given resource conflicts with another.""" + pass @@ -58,6 +63,7 @@ class ValidationFailedError(LabelboxError): """Exception raised for when a GraphQL query fails validation (query cost, etc.) E.g. a query that is too expensive, or depth is too deep. """ + pass @@ -68,25 +74,29 @@ class InternalServerError(LabelboxError): TODO: these errors need better messages from platform """ + pass class InvalidQueryError(LabelboxError): - """ Indicates a malconstructed or unsupported query (either by GraphQL in + """Indicates a malconstructed or unsupported query (either by GraphQL in general or by Labelbox specifically). This can be the result of either client - or server side query validation. """ + or server side query validation.""" + pass class UnprocessableEntityError(LabelboxError): - """ Indicates that a resource could not be created in the server side + """Indicates that a resource could not be created in the server side due to a validation or transaction error""" + pass class ResourceCreationError(LabelboxError): - """ Indicates that a resource could not be created in the server side + """Indicates that a resource could not be created in the server side due to a validation or transaction error""" + pass @@ -100,33 +110,39 @@ def __init__(self, cause): class TimeoutError(LabelboxError): """Raised when a request times-out.""" + pass class InvalidAttributeError(LabelboxError): - """ Raised when a field (name or Field instance) is not valid or found - for a specific DB object type. """ + """Raised when a field (name or Field instance) is not valid or found + for a specific DB object type.""" def __init__(self, db_object_type, field): - super().__init__("Field(s) '%r' not valid on DB type '%s'" % - (field, db_object_type.type_name())) + super().__init__( + "Field(s) '%r' not valid on DB type '%s'" + % (field, db_object_type.type_name()) + ) self.db_object_type = db_object_type self.field = field class ApiLimitError(LabelboxError): - """ Raised when the user performs too many requests in a short period - of time. 
""" + """Raised when the user performs too many requests in a short period + of time.""" + pass class MalformedQueryException(Exception): - """ Raised when the user submits a malformed query.""" + """Raised when the user submits a malformed query.""" + pass class UuidError(LabelboxError): - """ Raised when there are repeat Uuid's in bulk import request.""" + """Raised when there are repeat Uuid's in bulk import request.""" + pass @@ -136,16 +152,19 @@ class InconsistentOntologyException(Exception): class MALValidationError(LabelboxError): """Raised when user input is invalid for MAL imports.""" + pass class OperationNotAllowedException(Exception): """Raised when user does not have permissions to a resource or has exceeded usage limit""" + pass class OperationNotSupportedException(Exception): """Raised when sdk does not support requested operation""" + pass diff --git a/libs/labelbox/src/labelbox/orm/comparison.py b/libs/labelbox/src/labelbox/orm/comparison.py index 91c226652..7830549ea 100644 --- a/libs/labelbox/src/labelbox/orm/comparison.py +++ b/libs/labelbox/src/labelbox/orm/comparison.py @@ -1,4 +1,5 @@ from enum import Enum, auto + """ Classes for defining the client-side comparison operations used for filtering data in fetches. Intended for use by library internals and not by the end user. @@ -6,7 +7,7 @@ class LogicalExpressionComponent: - """ Implements bitwise logical operator methods (&, | and ~) so they + """Implements bitwise logical operator methods (&, | and ~) so they return a LogicalExpression object containing this LogicalExpressionComponent. """ @@ -26,22 +27,23 @@ def __invert__(self): class LogicalExpression(LogicalExpressionComponent): - """ A unary (NOT) or binary (AND, OR) logical expression between - Comparison or LogicalExpression objects. """ + """A unary (NOT) or binary (AND, OR) logical expression between + Comparison or LogicalExpression objects.""" class Op(Enum): - """ Type of logical operation. """ + """Type of logical operation.""" + AND = auto() OR = auto() NOT = auto() def __call__(self, first, second=None): - """ Forwards to LogicalExpression constructor, passing `self` - as the `op` argument. """ + """Forwards to LogicalExpression constructor, passing `self` + as the `op` argument.""" return LogicalExpression(self, first, second) def __init__(self, op, first, second=None): - """ LogicalExpression constructor. + """LogicalExpression constructor. Args: op (LogicalExpression.Op): The type of logical operation. @@ -54,12 +56,14 @@ def __init__(self, op, first, second=None): def __eq__(self, other): return self.op == other.op and ( - (self.first == other.first and self.second == other.second) or - (self.first == other.second and self.second == other.first)) + (self.first == other.first and self.second == other.second) + or (self.first == other.second and self.second == other.first) + ) def __hash__(self): - return hash( - self.op) + 2833 * hash(self.first) + 2837 * hash(self.second) + return ( + hash(self.op) + 2833 * hash(self.first) + 2837 * hash(self.second) + ) def __repr__(self): return "%r %s %r" % (self.first, self.op.name, self.second) @@ -69,11 +73,12 @@ def __str__(self): class Comparison(LogicalExpressionComponent): - """ A comparison between a database value (represented by a - `labelbox.schema.Field` object) and a constant value. """ + """A comparison between a database value (represented by a + `labelbox.schema.Field` object) and a constant value.""" class Op(Enum): - """ Type of the comparison operation. 
""" + """Type of the comparison operation.""" + EQ = auto() NE = auto() LT = auto() @@ -82,12 +87,12 @@ class Op(Enum): GE = auto() def __call__(self, *args): - """ Forwards to Comparison constructor, passing `self` - as the `op` argument. """ + """Forwards to Comparison constructor, passing `self` + as the `op` argument.""" return Comparison(self, *args) def __init__(self, op, field, value): - """ Comparison constructor. + """Comparison constructor. Args: op (Comparison.Op): The type of comparison. @@ -99,8 +104,11 @@ def __init__(self, op, field, value): self.value = value def __eq__(self, other): - return self.op == other.op and \ - self.field == other.field and self.value == other.value + return ( + self.op == other.op + and self.field == other.field + and self.value == other.value + ) def __hash__(self): return hash(self.op) + 2861 * hash(self.field) + 2927 * hash(self.value) diff --git a/libs/labelbox/src/labelbox/orm/db_object.py b/libs/labelbox/src/labelbox/orm/db_object.py index c4f87eac5..b210a8a5b 100644 --- a/libs/labelbox/src/labelbox/orm/db_object.py +++ b/libs/labelbox/src/labelbox/orm/db_object.py @@ -5,7 +5,11 @@ import json from labelbox import utils -from labelbox.exceptions import InvalidQueryError, InvalidAttributeError, OperationNotSupportedException +from labelbox.exceptions import ( + InvalidQueryError, + InvalidAttributeError, + OperationNotSupportedException, +) from labelbox.orm import query from labelbox.orm.model import Field, Relationship, Entity from labelbox.pagination import PaginatedCollection @@ -14,7 +18,7 @@ class DbObject(Entity): - """ A client-side representation of a database object (row). Intended as + """A client-side representation of a database object (row). Intended as base class for classes representing concrete database types (for example a Project). Exposes support functionalities so that the concrete subclass definition be as simple and DRY as possible. It should come down to just @@ -35,7 +39,7 @@ class DbObject(Entity): """ def __init__(self, client, field_values): - """ Constructor of a database object. Generally it should only be used + """Constructor of a database object. Generally it should only be used by library internals and not by the end user. Args: @@ -49,12 +53,16 @@ def __init__(self, client, field_values): value = field_values.get(utils.camel_case(relationship.name)) if relationship.cache and value is None: raise KeyError( - f"Expected field values for {relationship.name}") - setattr(self, relationship.name, - RelationshipManager(self, relationship, value)) + f"Expected field values for {relationship.name}" + ) + setattr( + self, + relationship.name, + RelationshipManager(self, relationship, value), + ) def _set_field_values(self, field_values): - """ Sets field values on this object. Ensures proper value conversions. + """Sets field values on this object. Ensures proper value conversions. Args: field_values (dict): Maps field names (GraphQL variant, snakeCase) to values. 
*Must* contain all field values for this object's @@ -69,7 +77,10 @@ def _set_field_values(self, field_values): except ValueError: logger.warning( "Failed to convert value '%s' to datetime for " - "field %s", value, field) + "field %s", + value, + field, + ) elif isinstance(field.field_type, Field.EnumType): value = field.field_type.enum_cls(value) elif isinstance(field.field_type, Field.ListType): @@ -80,7 +91,9 @@ def _set_field_values(self, field_values): except ValueError: logger.warning( "Failed to convert value '%s' to metadata for field %s", - value, field) + value, + field, + ) setattr(self, field.name, value) def __repr__(self): @@ -94,29 +107,34 @@ def __str__(self): attribute_values = { field.name: getattr(self, field.name) for field in self.fields() } - return "<%s %s>" % (self.type_name().split(".")[-1], - json.dumps(attribute_values, indent=4, default=str)) + return "<%s %s>" % ( + self.type_name().split(".")[-1], + json.dumps(attribute_values, indent=4, default=str), + ) def __eq__(self, other): - return (isinstance(other, DbObject) and - self.type_name() == other.type_name() and self.uid == other.uid) + return ( + isinstance(other, DbObject) + and self.type_name() == other.type_name() + and self.uid == other.uid + ) def __hash__(self): return 7541 * hash(self.type_name()) + hash(self.uid) class RelationshipManager: - """ Manages relationships (object fetching and updates) for a `DbObject` + """Manages relationships (object fetching and updates) for a `DbObject` instance. There is one RelationshipManager for each relationship in each `DbObject` instance. """ def __init__(self, source, relationship, value=None): """Args: - source (DbObject subclass instance): The object that's the source - of the relationship. - relationship (labelbox.schema.Relationship): The relationship - schema descriptor object. + source (DbObject subclass instance): The object that's the source + of the relationship. + relationship (labelbox.schema.Relationship): The relationship + schema descriptor object. """ self.source = source self.relationship = relationship @@ -127,8 +145,8 @@ def __init__(self, source, relationship, value=None): self.config = relationship.config def __call__(self, *args, **kwargs): - """ Forwards the call to either `_to_many` or `_to_one` methods, - depending on relationship type. """ + """Forwards the call to either `_to_many` or `_to_one` methods, + depending on relationship type.""" if self.relationship.deprecation_warning: logger.warning(self.relationship.deprecation_warning) @@ -139,7 +157,7 @@ def __call__(self, *args, **kwargs): return self._to_one(*args, **kwargs) def _to_many(self, where=None, order_by=None): - """ Returns an iterable over the destination relationship objects. + """Returns an iterable over the destination relationship objects. Args: where (None, Comparison or LogicalExpression): Filtering clause. order_by (None or (Field, Field.Order)): Ordering clause. 
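For illustration only (this is not part of the diff hunks): the `where` argument that the reformatted RelationshipManager._to_many accepts is built from the client-side comparison DSL in labelbox/orm/comparison.py. A minimal sketch, with purely illustrative field names rather than a real Labelbox schema:

    from labelbox.orm.comparison import LogicalExpression
    from labelbox.orm.model import Field

    name_field = Field.String("name")   # illustrative fields, not real schema fields
    uid_field = Field.ID("uid")

    # Field.__eq__ returns a Comparison instead of a bool; "&" combines
    # comparisons into a LogicalExpression (only AND is accepted server-side).
    clause = (name_field == "My dataset") & (uid_field == "ckabc1234567890")
    assert isinstance(clause, LogicalExpression)
    assert clause.op is LogicalExpression.Op.AND

An object like `clause` is what gets passed as `where=` to a to-many relationship accessor, provided that relationship supports filtering.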
@@ -149,27 +167,35 @@ def _to_many(self, where=None, order_by=None): rel = self.relationship if where is not None and not self.supports_filtering: raise InvalidQueryError( - "Relationship %s.%s doesn't support filtering" % - (self.source.type_name(), rel.name)) + "Relationship %s.%s doesn't support filtering" + % (self.source.type_name(), rel.name) + ) if order_by is not None and not self.supports_sorting: raise InvalidQueryError( - "Relationship %s.%s doesn't support sorting" % - (self.source.type_name(), rel.name)) + "Relationship %s.%s doesn't support sorting" + % (self.source.type_name(), rel.name) + ) if rel.filter_deleted: not_deleted = rel.destination_type.deleted == False where = not_deleted if where is None else where & not_deleted query_string, params = query.relationship( - self.source if self.filter_on_id else type(self.source), rel, where, - order_by) + self.source if self.filter_on_id else type(self.source), + rel, + where, + order_by, + ) return PaginatedCollection( - self.source.client, query_string, params, + self.source.client, + query_string, + params, [utils.camel_case(self.source.type_name()), rel.graphql_name], - rel.destination_type) + rel.destination_type, + ) def _to_one(self): - """ Returns the relationship destination object. """ + """Returns the relationship destination object.""" rel = self.relationship if self.value: @@ -178,7 +204,8 @@ def _to_one(self): query_string, params = query.relationship(self.source, rel, None, None) result = self.source.client.execute(query_string, params) result = result and result.get( - utils.camel_case(type(self.source).type_name())) + utils.camel_case(type(self.source).type_name()) + ) result = result and result.get(rel.graphql_name) if result is None: return None @@ -186,26 +213,28 @@ def _to_one(self): return rel.destination_type(self.source.client, result) def connect(self, other): - """ Connects source object of this manager to the `other` object. """ + """Connects source object of this manager to the `other` object.""" query_string, params = query.update_relationship( - self.source, other, self.relationship, "connect") + self.source, other, self.relationship, "connect" + ) self.source.client.execute(query_string, params) def disconnect(self, other): - """ Disconnects source object of this manager from the `other` object. """ + """Disconnects source object of this manager from the `other` object.""" if not self.config.disconnect_supported: raise OperationNotSupportedException( - "Disconnect is not supported for this relationship") + "Disconnect is not supported for this relationship" + ) query_string, params = query.update_relationship( - self.source, other, self.relationship, "disconnect") + self.source, other, self.relationship, "disconnect" + ) self.source.client.execute(query_string, params) class Updateable: - def update(self, **kwargs): - """ Updates this DB object with new values. Values should be + """Updates this DB object with new values. Values should be passed as key-value arguments with field names as keys: >>> db_object.update(name="New name", title="A title") @@ -229,10 +258,10 @@ def update(self, **kwargs): class Deletable: - """ Implements deletion for objects that have a `deleted` attribute. """ + """Implements deletion for objects that have a `deleted` attribute.""" def delete(self): - """ Deletes this DB object from the DB (server side). After + """Deletes this DB object from the DB (server side). After a call to this you should not use this DB object anymore. 
""" query_string, params = query.delete(self) @@ -240,7 +269,7 @@ def delete(self): class BulkDeletable: - """ Implements deletion for objects that have a custom, bulk deletion + """Implements deletion for objects that have a custom, bulk deletion mutation (accepts a list of IDs of objects to be deleted). A subclass must override the `bulk_delete` static method so it @@ -263,13 +292,14 @@ def _bulk_delete(objects, use_where_clause): types = {type(o) for o in objects} if len(types) != 1: raise InvalidQueryError( - "Can't bulk-delete objects of different types: %r" % types) + "Can't bulk-delete objects of different types: %r" % types + ) query_str, params = query.bulk_delete(objects, use_where_clause) objects[0].client.execute(query_str, params) def delete(self): - """ Deletes this DB object from the DB (server side). After + """Deletes this DB object from the DB (server side). After a call to this you should not use this DB object anymore. """ type(self).bulk_delete([self]) @@ -295,7 +325,8 @@ def wrapper(*args, **kwargs): else: raise ValueError( f"Static method {fn.__name__} must have a client passed in as the first " - f"argument or as a keyword argument.") + f"argument or as a keyword argument." + ) wrapped_fn = fn.__func__ else: client = args[0].client @@ -306,7 +337,8 @@ def wrapper(*args, **kwargs): f"This function {fn.__name__} relies on a experimental feature in the api. " f"This means that the interface could change. " f"Set `enable_experimental=True` in the client to enable use of " - f"experimental functions.") + f"experimental functions." + ) return wrapped_fn(*args, **kwargs) return wrapper diff --git a/libs/labelbox/src/labelbox/orm/model.py b/libs/labelbox/src/labelbox/orm/model.py index 5720b67cc..84dcac774 100644 --- a/libs/labelbox/src/labelbox/orm/model.py +++ b/libs/labelbox/src/labelbox/orm/model.py @@ -6,13 +6,14 @@ from labelbox import utils from labelbox.exceptions import InvalidAttributeError from labelbox.orm.comparison import Comparison + """ Defines Field, Relationship and Entity. These classes are building blocks for defining the Labelbox schema, DB object operations and queries. """ class Field: - """ Represents a field in a database table. A Field has a name, a type + """Represents a field in a database table. A Field has a name, a type (corresponds to server-side GraphQL type) and a server-side name. The server-side name is most often just a camelCase version of the client-side snake_case name. @@ -48,7 +49,6 @@ class Type(Enum): Json = auto() class EnumType: - def __init__(self, enum_cls: type): self.enum_cls = enum_cls @@ -57,7 +57,7 @@ def name(self): return self.enum_cls.__name__ class ListType: - """ Represents Field that is a list of some object. + """Represents Field that is a list of some object. Args: list_cls (type): Type of object that list is made of. graphql_type (str): Inner object's graphql type. @@ -76,7 +76,8 @@ def name(self): return f"[{self.graphql_type}]" class Order(Enum): - """ Type of sort ordering. """ + """Type of sort ordering.""" + Asc = auto() Desc = auto() @@ -116,12 +117,14 @@ def Json(*args): def List(list_cls: type, graphql_type=None, **kwargs): return Field(Field.ListType(list_cls, graphql_type), **kwargs) - def __init__(self, - field_type: Union[Type, EnumType, ListType], - name, - graphql_name=None, - result_subquery=None): - """ Field init. + def __init__( + self, + field_type: Union[Type, EnumType, ListType], + name, + graphql_name=None, + result_subquery=None, + ): + """Field init. 
Args: field_type (Field.Type): The type of the field. name (str): client-side Python attribute name of a database @@ -140,7 +143,7 @@ def __init__(self, @property def asc(self): - """ Property that resolves to tuple (Field, Field.Order). + """Property that resolves to tuple (Field, Field.Order). Used for easy definition of sort ordering: >>> projects_ordered = client.get_projects(order_by=Project.name.asc) """ @@ -148,14 +151,14 @@ def asc(self): @property def desc(self): - """ Property that resolves to tuple (Field, Field.Order). + """Property that resolves to tuple (Field, Field.Order). Used for easy definition of sort ordering: >>> projects_ordered = client.get_projects(order_by=Project.name.desc) """ return (self, Field.Order.Desc) def __eq__(self, other): - """ Equality of Fields has two meanings. If comparing to a Field object, + """Equality of Fields has two meanings. If comparing to a Field object, then a boolean indicator if the fields are identical is returned. If comparing to any other type, a Comparison object is created. """ @@ -165,7 +168,7 @@ def __eq__(self, other): return Comparison.Op.EQ(self, other) def __ne__(self, other): - """ Equality of Fields has two meanings. If comparing to a Field object, + """Equality of Fields has two meanings. If comparing to a Field object, then a boolean indicator if the fields are identical is returned. If comparing to any other type, a Comparison object is created. """ @@ -199,7 +202,7 @@ def __repr__(self): class Relationship: - """ Represents a relationship in a database table. + """Represents a relationship in a database table. Attributes: relationship_type (Relationship.Type): Indicator if to-one or to-many @@ -236,15 +239,17 @@ def ToOne(*args, **kwargs): def ToMany(*args, **kwargs): return Relationship(Relationship.Type.ToMany, *args, **kwargs) - def __init__(self, - relationship_type, - destination_type_name, - filter_deleted=True, - name=None, - graphql_name=None, - cache=False, - deprecation_warning=None, - config=Config()): + def __init__( + self, + relationship_type, + destination_type_name, + filter_deleted=True, + name=None, + graphql_name=None, + cache=False, + deprecation_warning=None, + config=Config(), + ): self.relationship_type = relationship_type self.destination_type_name = destination_type_name self.filter_deleted = filter_deleted @@ -254,7 +259,8 @@ def __init__(self, if name is None: name = utils.snake_case(destination_type_name) + ( - "s" if relationship_type == Relationship.Type.ToMany else "") + "s" if relationship_type == Relationship.Type.ToMany else "" + ) self.name = name if graphql_name is None: @@ -273,10 +279,11 @@ def __repr__(self): class EntityMeta(type): - """ Entity metaclass. Registers Entity subclasses as attributes + """Entity metaclass. Registers Entity subclasses as attributes of the Entity class object so they can be referenced for example like: Entity.Project. 
""" + # Maps Entity name to Relationships for all currently defined Entities relationship_mappings: Dict[str, List[Relationship]] = {} @@ -288,14 +295,16 @@ def __init__(cls, clsname, superclasses, attributedict): cls.validate_cached_relationships() if clsname != "Entity": setattr(Entity, clsname, cls) - EntityMeta.relationship_mappings[utils.snake_case( - cls.__name__)] = cls.relationships() + EntityMeta.relationship_mappings[utils.snake_case(cls.__name__)] = ( + cls.relationships() + ) @staticmethod def raise_for_nested_cache(first: str, middle: str, last: List[str]): raise TypeError( "Cannot cache a relationship to an Entity with its own cached relationship(s). " - f"`{first}` caches `{middle}` which caches `{last}`") + f"`{first}` caches `{middle}` which caches `{last}`" + ) @staticmethod def cached_entities(entity_name: str): @@ -329,8 +338,11 @@ def validate_cached_relationships(cls): for rel in cached_rels: nested = cls.cached_entities(rel.name) if nested: - cls.raise_for_nested_cache(utils.snake_case(cls.__name__), - rel.name, list(nested.keys())) + cls.raise_for_nested_cache( + utils.snake_case(cls.__name__), + rel.name, + list(nested.keys()), + ) # If the current Entity (cls) has any cached relationships (cached_rels) # then no other defined Entity (entities in EntityMeta.relationship_mappings) can cache this Entity. @@ -347,12 +359,13 @@ def validate_cached_relationships(cls): cls.raise_for_nested_cache( utils.snake_case(entity_name), utils.snake_case(cls.__name__), - [entity.name for entity in cached_rels]) + [entity.name for entity in cached_rels], + ) class Entity(metaclass=EntityMeta): - """ An entity that contains fields and relationships. Base class - for DbObject (which is base class for concrete schema classes). """ + """An entity that contains fields and relationships. Base class + for DbObject (which is base class for concrete schema classes).""" # Every Entity has an "id" and a "deleted" field # Name the "id" field "uid" in Python to avoid conflict with keyword. @@ -392,7 +405,7 @@ class Entity(metaclass=EntityMeta): @classmethod def _attributes_of_type(cls, attr_type): - """ Yields all the attributes in `cls` of the given `attr_type`. """ + """Yields all the attributes in `cls` of the given `attr_type`.""" for attr_name in dir(cls): attr = getattr(cls, attr_name) if isinstance(attr, attr_type): @@ -400,7 +413,7 @@ def _attributes_of_type(cls, attr_type): @classmethod def fields(cls): - """ Returns a generator that yields all the Fields declared in a + """Returns a generator that yields all the Fields declared in a concrete subclass. """ for attr in cls._attributes_of_type(Field): @@ -409,14 +422,14 @@ def fields(cls): @classmethod def relationships(cls): - """ Returns a generator that yields all the Relationships declared in + """Returns a generator that yields all the Relationships declared in a concrete subclass. """ return cls._attributes_of_type(Relationship) @classmethod def field(cls, field_name): - """ Returns a Field object for the given name. + """Returns a Field object for the given name. Args: field_name (str): Field name, Python (snake-case) convention. Return: @@ -432,7 +445,7 @@ def field(cls, field_name): @classmethod def attribute(cls, attribute_name): - """ Returns a Field or a Relationship object for the given name. + """Returns a Field or a Relationship object for the given name. Args: attribute_name (str): Field or Relationship name, Python (snake-case) convention. 
@@ -449,7 +462,7 @@ def attribute(cls, attribute_name): @classmethod def type_name(cls): - """ Returns this DB object type name in TitleCase. For example: - Project, DataRow, ... + """Returns this DB object type name in TitleCase. For example: + Project, DataRow, ... """ return cls.__name__.split(".")[-1] diff --git a/libs/labelbox/src/labelbox/orm/query.py b/libs/labelbox/src/labelbox/orm/query.py index f28714d09..8fa9fea00 100644 --- a/libs/labelbox/src/labelbox/orm/query.py +++ b/libs/labelbox/src/labelbox/orm/query.py @@ -2,14 +2,19 @@ from typing import Any, Dict from labelbox import utils -from labelbox.exceptions import InvalidQueryError, InvalidAttributeError, MalformedQueryException +from labelbox.exceptions import ( + InvalidQueryError, + InvalidAttributeError, + MalformedQueryException, +) from labelbox.orm.comparison import LogicalExpression, Comparison from labelbox.orm.model import Field, Relationship, Entity + """ Common query creation functionality. """ def format_param_declaration(params): - """ Formats the parameters dictionary into a declaration of GraphQL + """Formats the parameters dictionary into a declaration of GraphQL query parameters. Args: @@ -27,12 +32,18 @@ def attr_type(attr): else: return Field.Type.ID.name - return "(" + ", ".join("$%s: %s!" % (param, attr_type(attr)) - for param, (_, attr) in params.items()) + ")" + return ( + "(" + + ", ".join( + "$%s: %s!" % (param, attr_type(attr)) + for param, (_, attr) in params.items() + ) + + ")" + ) def results_query_part(entity): - """ Generates the results part of the query. The results contain + """Generates the results part of the query. The results contain all the entity's fields as well as prefetched relationships. Note that this is a recursive function. If there is a cycle in the @@ -44,30 +55,30 @@ def results_query_part(entity): # Query for fields fields = [ field.result_subquery - if field.result_subquery is not None else field.graphql_name + if field.result_subquery is not None + else field.graphql_name for field in entity.fields() ] # Query for cached relationships - fields.extend([ - Query(rel.graphql_name, rel.destination_type).format()[0] - for rel in entity.relationships() - if rel.cache - ]) + fields.extend( + [ + Query(rel.graphql_name, rel.destination_type).format()[0] + for rel in entity.relationships() + if rel.cache + ] + ) return " ".join(fields) class Query: - """ A data structure used during the construction of a query. Supports - subquery (also Query object) nesting for relationship. """ - - def __init__(self, - what, - subquery, - where=None, - paginate=False, - order_by=None): - """ Initializer. + """A data structure used during the construction of a query. Supports + subquery (also Query object) nesting for relationship.""" + + def __init__( + self, what, subquery, where=None, paginate=False, order_by=None + ): + """Initializer. Args: what (str): What is being queried. Typically an object type in singular or plural (i.e. "project" or "projects"). @@ -88,7 +99,7 @@ def __init__(self, self.order_by = order_by def format_subquery(self): - """ Formats the subquery (a Query or Entity subtype). """ + """Formats the subquery (a Query or Entity subtype).""" if isinstance(self.subquery, Query): return self.subquery.format() elif issubclass(self.subquery, Entity): @@ -97,14 +108,14 @@ def format_subquery(self): raise MalformedQueryException() def format_clauses(self, params): - """ Formats the where, order_by and pagination clauses. + """Formats the where, order_by and pagination clauses. 
Args: params (dict): The current parameter dictionary. """ def format_where(node): - """ Helper that resursively constructs a where clause from a - LogicalExpression tree (leaf nodes are Comparisons). """ + """Helper that resursively constructs a where clause from a + LogicalExpression tree (leaf nodes are Comparisons).""" COMPARISON_TO_SUFFIX = { Comparison.Op.EQ: "", Comparison.Op.NE: "_not", @@ -117,23 +128,29 @@ def format_where(node): if isinstance(node, Comparison): param_name = "param_%d" % len(params) params[param_name] = (node.value, node.field) - return "{%s%s: $%s}" % (node.field.graphql_name, - COMPARISON_TO_SUFFIX[node.op], - param_name) + return "{%s%s: $%s}" % ( + node.field.graphql_name, + COMPARISON_TO_SUFFIX[node.op], + param_name, + ) if node.op == LogicalExpression.Op.NOT: return "{NOT: [%s]}" % format_where(node.first) - return "{%s: [%s, %s]}" % (node.op.name.upper(), - format_where(node.first), - format_where(node.second)) + return "{%s: [%s, %s]}" % ( + node.op.name.upper(), + format_where(node.first), + format_where(node.second), + ) paginate = "skip: %d first: %d" if self.paginate else "" where = "where: %s" % format_where(self.where) if self.where else "" if self.order_by: - order_by = "orderBy: %s_%s" % (self.order_by[0].graphql_name, - self.order_by[1].name.upper()) + order_by = "orderBy: %s_%s" % ( + self.order_by[0].graphql_name, + self.order_by[1].name.upper(), + ) else: order_by = "" @@ -141,7 +158,7 @@ def format_where(node): return "(" + clauses + ")" if clauses else "" def format(self): - """ Formats the full query but without "query" prefix, query name + """Formats the full query but without "query" prefix, query name and parameter declaration. Return: (str, dict) tuple. str is the query and dict maps parameter @@ -153,7 +170,7 @@ def format(self): return query, params def format_top(self, name): - """ Formats the full query including "query" prefix, query name + """Formats the full query including "query" prefix, query name and parameter declaration. The result of this function can be sent to the Client object for execution. @@ -171,7 +188,7 @@ def format_top(self, name): def get_single(entity, uid): - """ Constructs the query and params dict for obtaining a single object. Either + """Constructs the query and params dict for obtaining a single object. Either on ID, or without params. Args: entity (type): An Entity subtype being obtained. @@ -181,12 +198,13 @@ def get_single(entity, uid): """ type_name = entity.type_name() where = entity.uid == uid if uid else None - return Query(utils.camel_case(type_name), entity, - where).format_top("Get" + type_name) + return Query(utils.camel_case(type_name), entity, where).format_top( + "Get" + type_name + ) def logical_ops(where): - """ Returns a generator that yields all the logical operator + """Returns a generator that yields all the logical operator type objects (`LogicalExpression.Op` instances) from a where clause. @@ -203,7 +221,7 @@ def logical_ops(where): def check_where_clause(entity, where): - """ Checks the `where` clause of a query. A `where` clause is legal + """Checks the `where` clause of a query. A `where` clause is legal if it only refers to fields found in the entity it's defined for. Since only AND logical operations are supported server-side at the moment, logical OR and NOT are illegal. @@ -217,7 +235,7 @@ def check_where_clause(entity, where): """ def fields(where): - """ Yields all the fields in a `where` clause. 
""" + """Yields all the fields in a `where` clause.""" if isinstance(where, LogicalExpression): for f in chain(fields(where.first), fields(where.second)): yield f @@ -233,15 +251,18 @@ def fields(where): if len(set(where_fields)) != len(where_fields): raise InvalidQueryError( "Where clause contains multiple comparisons for " - "the same field: %r." % where) + "the same field: %r." % where + ) if set(logical_ops(where)) not in (set(), {LogicalExpression.Op.AND}): - raise InvalidQueryError("Currently only AND logical ops are allowed in " - "the where clause of a query.") + raise InvalidQueryError( + "Currently only AND logical ops are allowed in " + "the where clause of a query." + ) def check_order_by_clause(entity, order_by): - """ Checks that the `order_by` clause field is a part of `entity`. + """Checks that the `order_by` clause field is a part of `entity`. Args: entity (type): An Entity subclass type. @@ -257,7 +278,7 @@ def check_order_by_clause(entity, order_by): def get_all(entity, where): - """ Constructs a query that fetches all items of the given type. The + """Constructs a query that fetches all items of the given type. The resulting query is intended to be used for pagination, it contains two python-string int-placeholders (%d) for 'skip' and 'first' pagination parameters. @@ -276,7 +297,7 @@ def get_all(entity, where): def relationship(source, relationship, where, order_by): - """ Constructs a query that fetches all items from a -to-many + """Constructs a query that fetches all items from a -to-many relationship. To be used like: >>> project = ... >>> query_str, params = relationship(Project, "datasets", Dataset) @@ -304,17 +325,24 @@ def relationship(source, relationship, where, order_by): check_where_clause(relationship.destination_type, where) check_order_by_clause(relationship.destination_type, order_by) to_many = relationship.relationship_type == Relationship.Type.ToMany - subquery = Query(relationship.graphql_name, relationship.destination_type, - where, to_many, order_by) - query_where = type(source).uid == source.uid if isinstance(source, Entity) \ - else None + subquery = Query( + relationship.graphql_name, + relationship.destination_type, + where, + to_many, + order_by, + ) + query_where = ( + type(source).uid == source.uid if isinstance(source, Entity) else None + ) query = Query(utils.camel_case(source.type_name()), subquery, query_where) - return query.format_top("Get" + source.type_name() + - utils.title_case(relationship.graphql_name)) + return query.format_top( + "Get" + source.type_name() + utils.title_case(relationship.graphql_name) + ) def create(entity, data): - """ Generates a query and parameters for creating a new DB object. + """Generates a query and parameters for creating a new DB object. 
Args: entity (type): An Entity subtype indicating which kind of @@ -330,8 +358,10 @@ def format_param_value(attribute, param): if isinstance(attribute, Field): return "%s: $%s" % (attribute.graphql_name, param) else: - return "%s: {connect: {id: $%s}}" % (utils.camel_case( - attribute.graphql_name), param) + return "%s: {connect: {id: $%s}}" % ( + utils.camel_case(attribute.graphql_name), + param, + ) # Convert data to params params = { @@ -339,16 +369,21 @@ def format_param_value(attribute, param): } query_str = """mutation Create%sPyApi%s{create%s(data: {%s}) {%s}} """ % ( - type_name, format_param_declaration(params), type_name, " ".join( + type_name, + format_param_declaration(params), + type_name, + " ".join( format_param_value(attribute, param) - for param, (_, attribute) in params.items()), - results_query_part(entity)) + for param, (_, attribute) in params.items() + ), + results_query_part(entity), + ) return query_str, {name: value for name, (value, _) in params.items()} def update_relationship(a, b, relationship, update): - """ Updates the relationship in DB object `a` to connect or disconnect + """Updates the relationship in DB object `a` to connect or disconnect DB object `b`. Args: @@ -360,8 +395,10 @@ def update_relationship(a, b, relationship, update): Return: (query_string, query_parameters) """ - to_one_disconnect = update == "disconnect" and \ - relationship.relationship_type == Relationship.Type.ToOne + to_one_disconnect = ( + update == "disconnect" + and relationship.relationship_type == Relationship.Type.ToOne + ) a_uid_param = utils.camel_case(type(a).type_name()) + "Id" @@ -375,9 +412,16 @@ def update_relationship(a, b, relationship, update): query_str = """mutation %s%sAnd%sPyApi%s{update%s( where: {id: $%s} data: {%s: {%s: %s}}) {id}} """ % ( - utils.title_case(update), type(a).type_name(), type(b).type_name(), - param_declr, utils.title_case(type(a).type_name()), a_uid_param, - relationship.graphql_name, update, b_query) + utils.title_case(update), + type(a).type_name(), + type(b).type_name(), + param_declr, + utils.title_case(type(a).type_name()), + a_uid_param, + relationship.graphql_name, + update, + b_query, + ) if to_one_disconnect: params = {a_uid_param: a.uid} @@ -388,7 +432,7 @@ def update_relationship(a, b, relationship, update): def update_fields(db_object, values): - """ Creates a query that updates `db_object` fields with the + """Creates a query that updates `db_object` fields with the given values. Args: @@ -400,8 +444,10 @@ def update_fields(db_object, values): """ type_name = db_object.type_name() id_param = "%sId" % type_name - values_str = " ".join("%s: $%s" % (field.graphql_name, field.graphql_name) - for field, _ in values.items()) + values_str = " ".join( + "%s: $%s" % (field.graphql_name, field.graphql_name) + for field, _ in values.items() + ) params = { field.graphql_name: (value, field) for field, value in values.items() } @@ -409,14 +455,19 @@ def update_fields(db_object, values): query_str = """mutation update%sPyApi%s{update%s( where: {id: $%s} data: {%s}) {%s}} """ % ( - utils.title_case(type_name), format_param_declaration(params), - type_name, id_param, values_str, results_query_part(type(db_object))) + utils.title_case(type_name), + format_param_declaration(params), + type_name, + id_param, + values_str, + results_query_part(type(db_object)), + ) return query_str, {name: value for name, (value, _) in params.items()} def delete(db_object): - """ Generates a query that deletes the given `db_object` from the DB. 
+ """Generates a query that deletes the given `db_object` from the DB. Args: db_object (DbObject): The DB object being deleted. @@ -424,14 +475,17 @@ def delete(db_object): id_param = "%sId" % db_object.type_name() query_str = """mutation delete%sPyApi%s{update%s( where: {id: $%s} data: {deleted: true}) {id}} """ % ( - db_object.type_name(), "($%s: ID!)" % id_param, db_object.type_name(), - id_param) + db_object.type_name(), + "($%s: ID!)" % id_param, + db_object.type_name(), + id_param, + ) return query_str, {id_param: db_object.uid} def bulk_delete(db_objects, use_where_clause): - """ Generates a query that bulk-deletes the given `db_objects` from the + """Generates a query that bulk-deletes the given `db_objects` from the DB. Args: @@ -441,13 +495,17 @@ def bulk_delete(db_objects, use_where_clause): """ type_name = db_objects[0].type_name() if use_where_clause: - query_str = "mutation delete%ssPyApi{delete%ss(where: {%sIds: [%s]}){id}}" + query_str = ( + "mutation delete%ssPyApi{delete%ss(where: {%sIds: [%s]}){id}}" + ) else: query_str = "mutation delete%ssPyApi{delete%ss(%sIds: [%s]){id}}" query_str = query_str % ( - utils.title_case(type_name), utils.title_case(type_name), - utils.camel_case(type_name), ", ".join( - '"%s"' % db_object.uid for db_object in db_objects)) + utils.title_case(type_name), + utils.title_case(type_name), + utils.camel_case(type_name), + ", ".join('"%s"' % db_object.uid for db_object in db_objects), + ) return query_str, {} diff --git a/libs/labelbox/src/labelbox/pagination.py b/libs/labelbox/src/labelbox/pagination.py index a173505c9..a3b170ec7 100644 --- a/libs/labelbox/src/labelbox/pagination.py +++ b/libs/labelbox/src/labelbox/pagination.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from typing import TYPE_CHECKING + if TYPE_CHECKING: from labelbox import Client from labelbox.orm.db_object import DbObject @@ -11,7 +12,7 @@ class PaginatedCollection: - """ An iterable collection of database objects (Projects, Labels, etc...). + """An iterable collection of database objects (Projects, Labels, etc...). Implements automatic (transparent to the user) paginated fetching during iteration. Intended for use by library internals and not by the end user. @@ -19,15 +20,17 @@ class PaginatedCollection: __init__ map exactly to object attributes. """ - def __init__(self, - client: "Client", - query: str, - params: Dict[str, Union[str, int]], - dereferencing: Union[List[str], Dict[str, Any]], - obj_class: Union[Type["DbObject"], Callable[[Any, Any], Any]], - cursor_path: Optional[List[str]] = None, - experimental: bool = False): - """ Creates a PaginatedCollection. + def __init__( + self, + client: "Client", + query: str, + params: Dict[str, Union[str, int]], + dereferencing: Union[List[str], Dict[str, Any]], + obj_class: Union[Type["DbObject"], Callable[[Any, Any], Any]], + cursor_path: Optional[List[str]] = None, + experimental: bool = False, + ): + """Creates a PaginatedCollection. Args: client (labelbox.Client): the client used for fetching data from DB. 
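For illustration only (outside the diff hunks): a hedged usage sketch of PaginatedCollection, whose constructor and pagination helpers are reformatted in this file. It assumes a valid API key and uses `client.get_projects()` as one example of a call that returns such a collection:

    from labelbox import Client

    client = Client(api_key="<YOUR_API_KEY>")
    projects = client.get_projects()    # PaginatedCollection; pages are fetched lazily

    first_two = projects.get_many(2)    # eagerly fetch up to the first two items
    for project in projects:            # or iterate; further pages load transparently
        print(project.uid, project.name)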
@@ -48,18 +51,19 @@ def __init__(self, self._data_ind = 0 pagination_kwargs = { - 'client': client, - 'obj_class': obj_class, - 'dereferencing': dereferencing, - 'experimental': experimental, - 'query': query, - 'params': params + "client": client, + "obj_class": obj_class, + "dereferencing": dereferencing, + "experimental": experimental, + "query": query, + "params": params, } - self.paginator = _CursorPagination( - cursor_path, ** - pagination_kwargs) if cursor_path else _OffsetPagination( - **pagination_kwargs) + self.paginator = ( + _CursorPagination(cursor_path, **pagination_kwargs) + if cursor_path + else _OffsetPagination(**pagination_kwargs) + ) def __iter__(self): self._data_ind = 0 @@ -107,10 +111,15 @@ def get_many(self, n: int): class _Pagination(ABC): - - def __init__(self, client: "Client", obj_class: Type["DbObject"], - dereferencing: Dict[str, Any], query: str, - params: Dict[str, Any], experimental: bool): + def __init__( + self, + client: "Client", + obj_class: Type["DbObject"], + dereferencing: Dict[str, Any], + query: str, + params: Dict[str, Any], + experimental: bool, + ): self.client = client self.obj_class = obj_class self.dereferencing = dereferencing @@ -125,16 +134,14 @@ def get_page_data(self, results: Dict[str, Any]) -> List["DbObject"]: return [self.obj_class(self.client, result) for result in results] @abstractmethod - def get_next_page(self) -> Tuple[Dict[str, Any], bool]: - ... + def get_next_page(self) -> Tuple[Dict[str, Any], bool]: ... class _CursorPagination(_Pagination): - def __init__(self, cursor_path: List[str], *args, **kwargs): super().__init__(*args, **kwargs) self.cursor_path = cursor_path - self.next_cursor: Optional[Any] = kwargs.get('params', {}).get('from') + self.next_cursor: Optional[Any] = kwargs.get("params", {}).get("from") def increment_page(self, results: Dict[str, Any]): for path in self.cursor_path: @@ -145,11 +152,11 @@ def fetched_all(self) -> bool: return not self.next_cursor def fetch_results(self) -> Dict[str, Any]: - page_size = self.params.get('first', _PAGE_SIZE) - self.params.update({'from': self.next_cursor, 'first': page_size}) - return self.client.execute(self.query, - self.params, - experimental=self.experimental) + page_size = self.params.get("first", _PAGE_SIZE) + self.params.update({"from": self.next_cursor, "first": page_size}) + return self.client.execute( + self.query, self.params, experimental=self.experimental + ) def get_next_page(self): results = self.fetch_results() @@ -160,7 +167,6 @@ def get_next_page(self): class _OffsetPagination(_Pagination): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._fetched_pages = 0 @@ -173,9 +179,9 @@ def fetched_all(self, n_items: int) -> bool: def fetch_results(self) -> Dict[str, Any]: query = self.query % (self._fetched_pages * _PAGE_SIZE, _PAGE_SIZE) - return self.client.execute(query, - self.params, - experimental=self.experimental) + return self.client.execute( + query, self.params, experimental=self.experimental + ) def get_next_page(self): results = self.fetch_results() diff --git a/libs/labelbox/src/labelbox/parser.py b/libs/labelbox/src/labelbox/parser.py index fab41bb81..8f64adaf4 100644 --- a/libs/labelbox/src/labelbox/parser.py +++ b/libs/labelbox/src/labelbox/parser.py @@ -2,24 +2,23 @@ class NdjsonDecoder(json.JSONDecoder): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def decode(self, s: str, *args, **kwargs): - lines = ','.join(s.splitlines()) + lines = ",".join(s.splitlines()) text = f"[{lines}]" # NOTE: 
this is a hack to make json.loads work for ndjson return super().decode(text, *args, **kwargs) def loads(ndjson_string, **kwargs) -> list: - kwargs.setdefault('cls', NdjsonDecoder) + kwargs.setdefault("cls", NdjsonDecoder) return json.loads(ndjson_string, **kwargs) def dumps(obj, **kwargs): lines = map(lambda obj: json.dumps(obj, **kwargs), obj) - return '\n'.join(lines) + return "\n".join(lines) def dump(obj, io, **kwargs): diff --git a/libs/labelbox/src/labelbox/schema/__init__.py b/libs/labelbox/src/labelbox/schema/__init__.py index 9f187bf87..03327e0d1 100644 --- a/libs/labelbox/src/labelbox/schema/__init__.py +++ b/libs/labelbox/src/labelbox/schema/__init__.py @@ -26,4 +26,4 @@ import labelbox.schema.identifiable import labelbox.schema.catalog import labelbox.schema.ontology_kind -import labelbox.schema.project_overview \ No newline at end of file +import labelbox.schema.project_overview diff --git a/libs/labelbox/src/labelbox/schema/annotation_import.py b/libs/labelbox/src/labelbox/schema/annotation_import.py index 2d1fd8582..df7f272a3 100644 --- a/libs/labelbox/src/labelbox/schema/annotation_import.py +++ b/libs/labelbox/src/labelbox/schema/annotation_import.py @@ -3,7 +3,16 @@ import logging import os import time -from typing import Any, BinaryIO, Dict, List, Optional, Union, TYPE_CHECKING, cast +from typing import ( + Any, + BinaryIO, + Dict, + List, + Optional, + Union, + TYPE_CHECKING, + cast, +) from collections import defaultdict from google.api_core import retry @@ -16,7 +25,9 @@ from labelbox.orm.db_object import DbObject from labelbox.orm.model import Field, Relationship from labelbox.utils import is_exactly_one_set -from labelbox.schema.confidence_presence_checker import LabelsConfidencePresenceChecker +from labelbox.schema.confidence_presence_checker import ( + LabelsConfidencePresenceChecker, +) from labelbox.schema.enums import AnnotationImportState from labelbox.schema.serialization import serialize_labels @@ -92,14 +103,14 @@ def statuses(self) -> List[Dict[str, Any]]: self.wait_until_done() return self._fetch_remote_ndjson(self.status_file_url) - def wait_till_done(self, - sleep_time_seconds: int = 10, - show_progress: bool = False) -> None: + def wait_till_done( + self, sleep_time_seconds: int = 10, show_progress: bool = False + ) -> None: self.wait_until_done(sleep_time_seconds, show_progress) - def wait_until_done(self, - sleep_time_seconds: int = 10, - show_progress: bool = False) -> None: + def wait_until_done( + self, sleep_time_seconds: int = 10, show_progress: bool = False + ) -> None: """Blocks import job until certain conditions are met. 
Blocks until the AnnotationImport.state changes either to `AnnotationImportState.FINISHED` or `AnnotationImportState.FAILED`, @@ -108,9 +119,14 @@ def wait_until_done(self, sleep_time_seconds (int): a time to block between subsequent API calls show_progress (bool): should show progress bar """ - pbar = tqdm(total=100, - bar_format="{n}% |{bar}| [{elapsed}, {rate_fmt}{postfix}]" - ) if show_progress else None + pbar = ( + tqdm( + total=100, + bar_format="{n}% |{bar}| [{elapsed}, {rate_fmt}{postfix}]", + ) + if show_progress + else None + ) while self.state.value == AnnotationImportState.RUNNING.value: logger.info(f"Sleeping for {sleep_time_seconds} seconds...") time.sleep(sleep_time_seconds) @@ -122,9 +138,13 @@ def wait_until_done(self, pbar.update(100 - pbar.n) pbar.close() - @retry.Retry(predicate=retry.if_exception_type( - labelbox.exceptions.ApiLimitError, labelbox.exceptions.TimeoutError, - labelbox.exceptions.NetworkError)) + @retry.Retry( + predicate=retry.if_exception_type( + labelbox.exceptions.ApiLimitError, + labelbox.exceptions.TimeoutError, + labelbox.exceptions.NetworkError, + ) + ) def __backoff_refresh(self) -> None: self.refresh() @@ -145,21 +165,24 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]: return parser.loads(response.text) @classmethod - def _create_from_bytes(cls, client, variables, query_str, file_name, - bytes_data) -> Dict[str, Any]: + def _create_from_bytes( + cls, client, variables, query_str, file_name, bytes_data + ) -> Dict[str, Any]: operations = json.dumps({"variables": variables, "query": query_str}) data = { "operations": operations, - "map": (None, json.dumps({file_name: ["variables.file"]})) + "map": (None, json.dumps({file_name: ["variables.file"]})), } file_data = (file_name, bytes_data, NDJSON_MIME_TYPE) files = {file_name: file_data} return client.execute(data=data, files=files) @classmethod - def _get_ndjson_from_objects(cls, objects: Union[List[Dict[str, Any]], - List["Label"]], - object_name: str) -> BinaryIO: + def _get_ndjson_from_objects( + cls, + objects: Union[List[Dict[str, Any]], List["Label"]], + object_name: str, + ) -> BinaryIO: if not isinstance(objects, list): raise TypeError( f"{object_name} must be in a form of list. Found {type(objects)}" @@ -173,17 +196,15 @@ def _get_ndjson_from_objects(cls, objects: Union[List[Dict[str, Any]], raise ValueError(f"{object_name} cannot be empty") return data_str.encode( - 'utf-8' + "utf-8" ) # NOTICE this method returns bytes, NOT BinaryIO... should have done io.BytesIO(...) but not going to change this at the moment since it works and fools mypy def refresh(self) -> None: - """Synchronizes values of all fields with the database. - """ + """Synchronizes values of all fields with the database.""" cls = type(self) - res = cls.from_name(self.client, - self.parent_id, - self.name, - as_json=True) + res = cls.from_name( + self.client, self.parent_id, self.name, as_json=True + ) self._set_field_values(res) @classmethod @@ -193,26 +214,29 @@ def _validate_data_rows(cls, objects: List[Dict[str, Any]]): and only one of 'id' or 'globalKey' is provided. 
Shows up to `max_num_errors` errors if invalidated, to prevent - large number of error messages from being printed out + large number of error messages from being printed out """ errors = [] max_num_errors = 100 labels_per_datarow: Dict[str, Dict[str, int]] = defaultdict( - lambda: defaultdict(int)) + lambda: defaultdict(int) + ) for object in objects: - if 'dataRow' not in object: + if "dataRow" not in object: errors.append(f"'dataRow' is missing in {object}") continue - data_row_object = object['dataRow'] - if not is_exactly_one_set(data_row_object.get('id'), - data_row_object.get('globalKey')): + data_row_object = object["dataRow"] + if not is_exactly_one_set( + data_row_object.get("id"), data_row_object.get("globalKey") + ): errors.append( f"Must provide only one of 'id' or 'globalKey' for 'dataRow' in {object}" ) else: data_row_id = data_row_object.get( - 'globalKey') or data_row_object.get('id') - name = object.get('name') + "globalKey" + ) or data_row_object.get("id") + name = object.get("name") if name: labels_per_datarow[data_row_id][name] += 1 for data_row_id, label_annotations in labels_per_datarow.items(): @@ -224,7 +248,7 @@ def _validate_data_rows(cls, objects: List[Dict[str, Any]]): ) if errors: errors_length = len(errors) - formatted_errors = '\n'.join(errors[:max_num_errors]) + formatted_errors = "\n".join(errors[:max_num_errors]) if errors_length > max_num_errors: logger.warning( f"Found more than {max_num_errors} errors. Showing first {max_num_errors} error messages..." @@ -234,11 +258,13 @@ def _validate_data_rows(cls, objects: List[Dict[str, Any]]): ) @classmethod - def from_name(cls, - client: "labelbox.Client", - parent_id: str, - name: str, - as_json: bool = False): + def from_name( + cls, + client: "labelbox.Client", + parent_id: str, + name: str, + as_json: bool = False, + ): raise NotImplementedError("Inheriting class must override") @property @@ -247,7 +273,6 @@ def parent_id(self) -> str: class CreatableAnnotationImport(AnnotationImport): - @classmethod def create( cls, @@ -256,9 +281,9 @@ def create( name: str, path: Optional[str] = None, url: Optional[str] = None, - labels: Union[List[Dict[str, Any]], List["Label"]] = [] + labels: Union[List[Dict[str, Any]], List["Label"]] = [], ) -> "AnnotationImport": - if (not is_exactly_one_set(url, labels, path)): + if not is_exactly_one_set(url, labels, path): raise ValueError( "Must pass in a nonempty argument for one and only one of the following arguments: url, path, predictions" ) @@ -269,20 +294,25 @@ def create( return cls.create_from_objects(client, id, name, labels) @classmethod - def create_from_url(cls, client: "labelbox.Client", id: str, name: str, - url: str) -> "AnnotationImport": + def create_from_url( + cls, client: "labelbox.Client", id: str, name: str, url: str + ) -> "AnnotationImport": raise NotImplementedError("Inheriting class must override") @classmethod - def create_from_file(cls, client: "labelbox.Client", id: str, name: str, - path: str) -> "AnnotationImport": + def create_from_file( + cls, client: "labelbox.Client", id: str, name: str, path: str + ) -> "AnnotationImport": raise NotImplementedError("Inheriting class must override") @classmethod def create_from_objects( - cls, client: "labelbox.Client", id: str, name: str, - labels: Union[List[Dict[str, Any]], - List["Label"]]) -> "AnnotationImport": + cls, + client: "labelbox.Client", + id: str, + name: str, + labels: Union[List[Dict[str, Any]], List["Label"]], + ) -> "AnnotationImport": raise NotImplementedError("Inheriting class must 
override") @@ -297,8 +327,9 @@ def parent_id(self) -> str: return self.model_run_id @classmethod - def create_from_file(cls, client: "labelbox.Client", model_run_id: str, - name: str, path: str) -> "MEAPredictionImport": + def create_from_file( + cls, client: "labelbox.Client", model_run_id: str, name: str, path: str + ) -> "MEAPredictionImport": """ Create an MEA prediction import job from a file of annotations @@ -311,17 +342,20 @@ def create_from_file(cls, client: "labelbox.Client", model_run_id: str, MEAPredictionImport """ if os.path.exists(path): - with open(path, 'rb') as f: + with open(path, "rb") as f: return cls._create_mea_import_from_bytes( - client, model_run_id, name, f, - os.stat(path).st_size) + client, model_run_id, name, f, os.stat(path).st_size + ) else: raise ValueError(f"File {path} is not accessible") @classmethod def create_from_objects( - cls, client: "labelbox.Client", model_run_id: str, name, - predictions: Union[List[Dict[str, Any]], List["Label"]] + cls, + client: "labelbox.Client", + model_run_id: str, + name, + predictions: Union[List[Dict[str, Any]], List["Label"]], ) -> "MEAPredictionImport": """ Create an MEA prediction import job from an in memory dictionary @@ -334,14 +368,16 @@ def create_from_objects( Returns: MEAPredictionImport """ - data = cls._get_ndjson_from_objects(predictions, 'annotations') + data = cls._get_ndjson_from_objects(predictions, "annotations") - return cls._create_mea_import_from_bytes(client, model_run_id, name, - data, len(str(data))) + return cls._create_mea_import_from_bytes( + client, model_run_id, name, data, len(str(data)) + ) @classmethod - def create_from_url(cls, client: "labelbox.Client", model_run_id: str, - name: str, url: str) -> "MEAPredictionImport": + def create_from_url( + cls, client: "labelbox.Client", model_run_id: str, name: str, url: str + ) -> "MEAPredictionImport": """ Create an MEA prediction import job from a url The url must point to a file containing prediction annotations. @@ -358,21 +394,26 @@ def create_from_url(cls, client: "labelbox.Client", model_run_id: str, query_str = cls._get_url_mutation() return cls( client, - client.execute(query_str, - params={ - "fileUrl": url, - "modelRunId": model_run_id, - 'name': name - })["createModelErrorAnalysisPredictionImport"]) + client.execute( + query_str, + params={ + "fileUrl": url, + "modelRunId": model_run_id, + "name": name, + }, + )["createModelErrorAnalysisPredictionImport"], + ) else: raise ValueError(f"Url {url} is not reachable") @classmethod - def from_name(cls, - client: "labelbox.Client", - model_run_id: str, - name: str, - as_json: bool = False) -> "MEAPredictionImport": + def from_name( + cls, + client: "labelbox.Client", + model_run_id: str, + name: str, + as_json: bool = False, + ) -> "MEAPredictionImport": """ Retrieves an MEA import job. 
@@ -395,7 +436,8 @@ def from_name(cls, response = client.execute(query_str, params) if response is None: raise labelbox.exceptions.ResourceNotFoundError( - MEAPredictionImport, params) + MEAPredictionImport, params + ) response = response["modelErrorAnalysisPredictionImport"] if as_json: return response @@ -421,14 +463,19 @@ def _get_file_mutation(cls) -> str: @classmethod def _create_mea_import_from_bytes( - cls, client: "labelbox.Client", model_run_id: str, name: str, - bytes_data: BinaryIO, content_len: int) -> "MEAPredictionImport": + cls, + client: "labelbox.Client", + model_run_id: str, + name: str, + bytes_data: BinaryIO, + content_len: int, + ) -> "MEAPredictionImport": file_name = f"{model_run_id}__{name}.ndjson" variables = { "file": None, "contentLength": content_len, "modelRunId": model_run_id, - "name": name + "name": name, } query_str = cls._get_file_mutation() res = cls._create_from_bytes( @@ -452,10 +499,14 @@ def parent_id(self) -> str: return self.project().uid @classmethod - def create_for_model_run_data_rows(cls, client: "labelbox.Client", - model_run_id: str, - data_row_ids: List[str], project_id: str, - name: str) -> "MEAToMALPredictionImport": + def create_for_model_run_data_rows( + cls, + client: "labelbox.Client", + model_run_id: str, + data_row_ids: List[str], + project_id: str, + name: str, + ) -> "MEAToMALPredictionImport": """ Create an MEA to MAL prediction import job from a list of data row ids of a specific model run @@ -469,20 +520,25 @@ def create_for_model_run_data_rows(cls, client: "labelbox.Client", query_str = cls._get_model_run_data_rows_mutation() return cls( client, - client.execute(query_str, - params={ - "dataRowIds": data_row_ids, - "modelRunId": model_run_id, - "projectId": project_id, - "name": name - })["createMalPredictionImportForModelRunDataRows"]) + client.execute( + query_str, + params={ + "dataRowIds": data_row_ids, + "modelRunId": model_run_id, + "projectId": project_id, + "name": name, + }, + )["createMalPredictionImportForModelRunDataRows"], + ) @classmethod - def from_name(cls, - client: "labelbox.Client", - project_id: str, - name: str, - as_json: bool = False) -> "MEAToMALPredictionImport": + def from_name( + cls, + client: "labelbox.Client", + project_id: str, + name: str, + as_json: bool = False, + ) -> "MEAToMALPredictionImport": """ Retrieves an MEA to MAL import job. 
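
create_for_model_run_data_rows, shown above, differs from the other constructors: rather than uploading annotations, it copies a model run's existing predictions for specific data rows into a project as pre-labels. A hedged sketch, reusing the client from the previous example and with every ID a placeholder:

    from labelbox.schema.annotation_import import MEAToMALPredictionImport

    job = MEAToMALPredictionImport.create_for_model_run_data_rows(
        client=client,                          # client from the earlier sketch
        model_run_id="<MODEL_RUN_ID>",          # placeholder
        data_row_ids=["<DATA_ROW_ID_1>", "<DATA_ROW_ID_2>"],
        project_id="<PROJECT_ID>",
        name="mea-to-mal-import",
    )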
@@ -505,7 +561,8 @@ def from_name(cls, response = client.execute(query_str, params) if response is None: raise labelbox.exceptions.ResourceNotFoundError( - MALPredictionImport, params) + MALPredictionImport, params + ) response = response["meaToMalPredictionImport"] if as_json: return response @@ -534,8 +591,9 @@ def parent_id(self) -> str: return self.project().uid @classmethod - def create_from_file(cls, client: "labelbox.Client", project_id: str, - name: str, path: str) -> "MALPredictionImport": + def create_from_file( + cls, client: "labelbox.Client", project_id: str, name: str, path: str + ) -> "MALPredictionImport": """ Create an MAL prediction import job from a file of annotations @@ -548,17 +606,20 @@ def create_from_file(cls, client: "labelbox.Client", project_id: str, MALPredictionImport """ if os.path.exists(path): - with open(path, 'rb') as f: + with open(path, "rb") as f: return cls._create_mal_import_from_bytes( - client, project_id, name, f, - os.stat(path).st_size) + client, project_id, name, f, os.stat(path).st_size + ) else: raise ValueError(f"File {path} is not accessible") @classmethod def create_from_objects( - cls, client: "labelbox.Client", project_id: str, name: str, - predictions: Union[List[Dict[str, Any]], List["Label"]] + cls, + client: "labelbox.Client", + project_id: str, + name: str, + predictions: Union[List[Dict[str, Any]], List["Label"]], ) -> "MALPredictionImport": """ Create an MAL prediction import job from an in memory dictionary @@ -572,22 +633,25 @@ def create_from_objects( MALPredictionImport """ - data = cls._get_ndjson_from_objects(predictions, 'annotations') + data = cls._get_ndjson_from_objects(predictions, "annotations") if len(predictions) > 0 and isinstance(predictions[0], Dict): predictions_dicts = cast(List[Dict[str, Any]], predictions) has_confidence = LabelsConfidencePresenceChecker.check( - predictions_dicts) + predictions_dicts + ) if has_confidence: logger.warning(""" Confidence scores are not supported in MAL Prediction Import. Corresponding confidence score values will be ignored. """) - return cls._create_mal_import_from_bytes(client, project_id, name, data, - len(str(data))) + return cls._create_mal_import_from_bytes( + client, project_id, name, data, len(str(data)) + ) @classmethod - def create_from_url(cls, client: "labelbox.Client", project_id: str, - name: str, url: str) -> "MALPredictionImport": + def create_from_url( + cls, client: "labelbox.Client", project_id: str, name: str, url: str + ) -> "MALPredictionImport": """ Create an MAL prediction import job from a url The url must point to a file containing prediction annotations. @@ -609,17 +673,21 @@ def create_from_url(cls, client: "labelbox.Client", project_id: str, params={ "fileUrl": url, "projectId": project_id, - 'name': name - })["createModelAssistedLabelingPredictionImport"]) + "name": name, + }, + )["createModelAssistedLabelingPredictionImport"], + ) else: raise ValueError(f"Url {url} is not reachable") @classmethod - def from_name(cls, - client: "labelbox.Client", - project_id: str, - name: str, - as_json: bool = False) -> "MALPredictionImport": + def from_name( + cls, + client: "labelbox.Client", + project_id: str, + name: str, + as_json: bool = False, + ) -> "MALPredictionImport": """ Retrieves an MAL import job. 
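
As the warning inside MALPredictionImport.create_from_objects notes, confidence scores are accepted in the payload but ignored for MAL imports. A small sketch of the kind of payload that triggers that warning; the schema ID, data row ID, and UUID are placeholders:

    from labelbox.schema.annotation_import import MALPredictionImport

    predictions = [
        {
            "uuid": "9fd9a92e-2560-4e77-81d4-b2e955800092",  # any unique UUID
            "schemaId": "<TEXT_FEATURE_SCHEMA_ID>",          # placeholder
            "dataRow": {"id": "<DATA_ROW_ID>"},              # placeholder
            "answer": "free text answer",
            "confidence": 0.87,  # present, but MAL import only logs a warning and ignores it
        }
    ]

    job = MALPredictionImport.create_from_objects(
        client,
        project_id="<PROJECT_ID>",
        name="mal-import",
        predictions=predictions,
    )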
@@ -642,7 +710,8 @@ def from_name(cls, response = client.execute(query_str, params) if response is None: raise labelbox.exceptions.ResourceNotFoundError( - MALPredictionImport, params) + MALPredictionImport, params + ) response = response["modelAssistedLabelingPredictionImport"] if as_json: return response @@ -668,18 +737,24 @@ def _get_file_mutation(cls) -> str: @classmethod def _create_mal_import_from_bytes( - cls, client: "labelbox.Client", project_id: str, name: str, - bytes_data: BinaryIO, content_len: int) -> "MALPredictionImport": + cls, + client: "labelbox.Client", + project_id: str, + name: str, + bytes_data: BinaryIO, + content_len: int, + ) -> "MALPredictionImport": file_name = f"{project_id}__{name}.ndjson" variables = { "file": None, "contentLength": content_len, "projectId": project_id, - "name": name + "name": name, } query_str = cls._get_file_mutation() - res = cls._create_from_bytes(client, variables, query_str, file_name, - bytes_data) + res = cls._create_from_bytes( + client, variables, query_str, file_name, bytes_data + ) return cls(client, res["createModelAssistedLabelingPredictionImport"]) @@ -694,8 +769,9 @@ def parent_id(self) -> str: return self.project().uid @classmethod - def create_from_file(cls, client: "labelbox.Client", project_id: str, - name: str, path: str) -> "LabelImport": + def create_from_file( + cls, client: "labelbox.Client", project_id: str, name: str, path: str + ) -> "LabelImport": """ Create a label import job from a file of annotations @@ -708,18 +784,21 @@ def create_from_file(cls, client: "labelbox.Client", project_id: str, LabelImport """ if os.path.exists(path): - with open(path, 'rb') as f: + with open(path, "rb") as f: return cls._create_label_import_from_bytes( - client, project_id, name, f, - os.stat(path).st_size) + client, project_id, name, f, os.stat(path).st_size + ) else: raise ValueError(f"File {path} is not accessible") @classmethod def create_from_objects( - cls, client: "labelbox.Client", project_id: str, name: str, - labels: Union[List[Dict[str, Any]], - List["Label"]]) -> "LabelImport": + cls, + client: "labelbox.Client", + project_id: str, + name: str, + labels: Union[List[Dict[str, Any]], List["Label"]], + ) -> "LabelImport": """ Create a label import job from an in memory dictionary @@ -731,7 +810,7 @@ def create_from_objects( Returns: LabelImport """ - data = cls._get_ndjson_from_objects(labels, 'labels') + data = cls._get_ndjson_from_objects(labels, "labels") if len(labels) > 0 and isinstance(labels[0], Dict): label_dicts = cast(List[Dict[str, Any]], labels) @@ -741,12 +820,14 @@ def create_from_objects( Confidence scores are not supported in Label Import. Corresponding confidence score values will be ignored. """) - return cls._create_label_import_from_bytes(client, project_id, name, - data, len(str(data))) + return cls._create_label_import_from_bytes( + client, project_id, name, data, len(str(data)) + ) @classmethod - def create_from_url(cls, client: "labelbox.Client", project_id: str, - name: str, url: str) -> "LabelImport": + def create_from_url( + cls, client: "labelbox.Client", project_id: str, name: str, url: str + ) -> "LabelImport": """ Create a label annotation import job from a url The url must point to a file containing label annotations. 
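
For ground-truth labels the pattern mirrors the prediction imports but is keyed to a project rather than a model run. A minimal sketch, assuming a local NDJSON file of labels; the project ID and file path are placeholders:

    from labelbox.schema.annotation_import import LabelImport

    label_job = LabelImport.create_from_file(
        client,
        project_id="<PROJECT_ID>",
        name="label-import",
        path="labels.ndjson",  # local NDJSON file of ground-truth labels
    )

    # Re-fetch the job as plain JSON to inspect it without building an object.
    job_json = LabelImport.from_name(
        client, project_id="<PROJECT_ID>", name="label-import", as_json=True
    )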
@@ -763,21 +844,26 @@ def create_from_url(cls, client: "labelbox.Client", project_id: str, query_str = cls._get_url_mutation() return cls( client, - client.execute(query_str, - params={ - "fileUrl": url, - "projectId": project_id, - 'name': name - })["createLabelImport"]) + client.execute( + query_str, + params={ + "fileUrl": url, + "projectId": project_id, + "name": name, + }, + )["createLabelImport"], + ) else: raise ValueError(f"Url {url} is not reachable") @classmethod - def from_name(cls, - client: "labelbox.Client", - project_id: str, - name: str, - as_json: bool = False) -> "LabelImport": + def from_name( + cls, + client: "labelbox.Client", + project_id: str, + name: str, + as_json: bool = False, + ) -> "LabelImport": """ Retrieves an label import job. @@ -824,18 +910,23 @@ def _get_file_mutation(cls) -> str: }""" % query.results_query_part(cls) @classmethod - def _create_label_import_from_bytes(cls, client: "labelbox.Client", - project_id: str, name: str, - bytes_data: BinaryIO, - content_len: int) -> "LabelImport": + def _create_label_import_from_bytes( + cls, + client: "labelbox.Client", + project_id: str, + name: str, + bytes_data: BinaryIO, + content_len: int, + ) -> "LabelImport": file_name = f"{project_id}__{name}.ndjson" variables = { "file": None, "contentLength": content_len, "projectId": project_id, - "name": name + "name": name, } query_str = cls._get_file_mutation() - res = cls._create_from_bytes(client, variables, query_str, file_name, - bytes_data) + res = cls._create_from_bytes( + client, variables, query_str, file_name, bytes_data + ) return cls(client, res["createLabelImport"]) diff --git a/libs/labelbox/src/labelbox/schema/asset_attachment.py b/libs/labelbox/src/labelbox/schema/asset_attachment.py index fba542011..0d5598c84 100644 --- a/libs/labelbox/src/labelbox/schema/asset_attachment.py +++ b/libs/labelbox/src/labelbox/schema/asset_attachment.py @@ -7,12 +7,12 @@ class AttachmentType(str, Enum): - @classmethod def __missing__(cls, value: object): if str(value) == "TEXT": warnings.warn( - "The TEXT attachment type is deprecated. Use RAW_TEXT instead.") + "The TEXT attachment type is deprecated. Use RAW_TEXT instead." + ) return cls.RAW_TEXT return value @@ -44,13 +44,13 @@ class AssetAttachment(DbObject): @classmethod def validate_attachment_json(cls, attachment_json: Dict[str, str]) -> None: - for required_key in ['type', 'value']: + for required_key in ["type", "value"]: if required_key not in attachment_json: raise ValueError( f"Must provide a `{required_key}` key for each attachment. Found {attachment_json}." 
) - cls.validate_attachment_value(attachment_json['value']) - cls.validate_attachment_type(attachment_json['type']) + cls.validate_attachment_value(attachment_json["value"]) + cls.validate_attachment_type(attachment_json["type"]) @classmethod def validate_attachment_value(cls, attachment_value: str) -> None: @@ -75,10 +75,12 @@ def delete(self) -> None: }""" self.client.execute(query_str, {"attachment_id": self.uid}) - def update(self, - name: Optional[str] = None, - type: Optional[str] = None, - value: Optional[str] = None): + def update( + self, + name: Optional[str] = None, + type: Optional[str] = None, + value: Optional[str] = None, + ): """Updates an attachment on the data row.""" if not name and not type and value is None: raise ValueError( @@ -101,9 +103,10 @@ def update(self, data: {name: $name, type: $type, value: $value} ) { id name type value } }""" - res = (self.client.execute(query_str, - query_params))['updateDataRowAttachment'] + res = (self.client.execute(query_str, query_params))[ + "updateDataRowAttachment" + ] - self.attachment_name = res['name'] - self.attachment_value = res['value'] - self.attachment_type = res['type'] + self.attachment_name = res["name"] + self.attachment_value = res["value"] + self.attachment_type = res["type"] diff --git a/libs/labelbox/src/labelbox/schema/batch.py b/libs/labelbox/src/labelbox/schema/batch.py index 313d02c16..7566a73f6 100644 --- a/libs/labelbox/src/labelbox/schema/batch.py +++ b/libs/labelbox/src/labelbox/schema/batch.py @@ -18,7 +18,7 @@ class Batch(DbObject): - """ A Batch is a group of data rows submitted to a project for labeling + """A Batch is a group of data rows submitted to a project for labeling Attributes: name (str) @@ -30,6 +30,7 @@ class Batch(DbObject): created_by (Relationship): `ToOne` relationship to User """ + name = Field.String("name") created_at = Field.DateTime("created_at") updated_at = Field.DateTime("updated_at") @@ -39,18 +40,15 @@ class Batch(DbObject): # Relationships created_by = Relationship.ToOne("User") - def __init__(self, - client, - project_id, - *args, - failed_data_row_ids=[], - **kwargs): + def __init__( + self, client, project_id, *args, failed_data_row_ids=[], **kwargs + ): super().__init__(client, *args, **kwargs) self.project_id = project_id self._failed_data_row_ids = failed_data_row_ids - def project(self) -> 'Project': # type: ignore - """ Returns Project which this Batch belongs to + def project(self) -> "Project": # type: ignore + """Returns Project which this Batch belongs to Raises: LabelboxError: if the project is not found @@ -69,7 +67,7 @@ def project(self) -> 'Project': # type: ignore return Entity.Project(self.client, response["project"]) def remove_queued_data_rows(self) -> None: - """ Removes remaining queued data rows from the batch and labeling queue. + """Removes remaining queued data rows from the batch and labeling queue. Args: batch (Batch): Batch to remove queued data rows from @@ -80,17 +78,21 @@ def remove_queued_data_rows(self) -> None: self.client.execute( """mutation RemoveQueuedDataRowsFromBatchPyApi($%s: ID!, $%s: ID!) 
{ project(where: {id: $%s}) { removeQueuedDataRowsFromBatch(batchId: $%s) { id } } - }""" % (project_id_param, batch_id_param, project_id_param, - batch_id_param), { - project_id_param: self.project_id, - batch_id_param: self.uid - }, - experimental=True) - - def export_data_rows(self, - timeout_seconds=120, - include_metadata: bool = False) -> Generator: - """ Returns a generator that produces all data rows that are currently + }""" + % ( + project_id_param, + batch_id_param, + project_id_param, + batch_id_param, + ), + {project_id_param: self.project_id, batch_id_param: self.uid}, + experimental=True, + ) + + def export_data_rows( + self, timeout_seconds=120, include_metadata: bool = False + ) -> Generator: + """Returns a generator that produces all data rows that are currently in this batch. Note: For efficiency, the data are cached for 30 minutes. Newly created data rows will not appear @@ -106,7 +108,8 @@ def export_data_rows(self, """ warnings.warn( "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) + DeprecationWarning, + ) id_param = "batchId" metadata_param = "includeMetadataInput" @@ -115,10 +118,10 @@ def export_data_rows(self, """ % (id_param, metadata_param, id_param, metadata_param) sleep_time = 2 while True: - res = self.client.execute(query_str, { - id_param: self.uid, - metadata_param: include_metadata - }) + res = self.client.execute( + query_str, + {id_param: self.uid, metadata_param: include_metadata}, + ) res = res["exportBatchDataRows"] if res["status"] == "COMPLETE": download_url = res["downloadUrl"] @@ -126,7 +129,8 @@ def export_data_rows(self, response.raise_for_status() reader = parser.reader(StringIO(response.text)) return ( - Entity.DataRow(self.client, result) for result in reader) + Entity.DataRow(self.client, result) for result in reader + ) elif res["status"] == "FAILED": raise LabelboxError("Data row export failed.") @@ -136,14 +140,15 @@ def export_data_rows(self, f"Unable to export data rows within {timeout_seconds} seconds." ) - logger.debug("Batch '%s' data row export, waiting for server...", - self.uid) + logger.debug( + "Batch '%s' data row export, waiting for server...", self.uid + ) time.sleep(sleep_time) def delete(self) -> None: - """ Deletes the given batch. + """Deletes the given batch. - Note: Batch deletion for batches that has labels is forbidden. + Note: Batch deletion for batches that has labels is forbidden. Args: batch (Batch): Batch to remove queued data rows from @@ -151,17 +156,22 @@ def delete(self) -> None: project_id_param = "projectId" batch_id_param = "batchId" - self.client.execute("""mutation DeleteBatchPyApi($%s: ID!, $%s: ID!) { + self.client.execute( + """mutation DeleteBatchPyApi($%s: ID!, $%s: ID!) { project(where: {id: $%s}) { deleteBatch(batchId: $%s) { deletedBatchId } } - }""" % (project_id_param, batch_id_param, project_id_param, - batch_id_param), { - project_id_param: self.project_id, - batch_id_param: self.uid - }, - experimental=True) + }""" + % ( + project_id_param, + batch_id_param, + project_id_param, + batch_id_param, + ), + {project_id_param: self.project_id, batch_id_param: self.uid}, + experimental=True, + ) def delete_labels(self, set_labels_as_template=False) -> None: - """ Deletes labels that were created for data rows in the batch. 
+ """Deletes labels that were created for data rows in the batch. Args: batch (Batch): Batch to remove queued data rows from @@ -174,17 +184,24 @@ def delete_labels(self, set_labels_as_template=False) -> None: res = self.client.execute( """mutation DeleteBatchLabelsPyApi($%s: ID!, $%s: ID!, $%s: DeleteBatchLabelsType!) { project(where: {id: $%s}) { deleteBatchLabels(batchId: $%s, data:{ type: $%s }) { deletedLabelIds } } - }""" % (project_id_param, batch_id_param, type_param, project_id_param, - batch_id_param, type_param), { - project_id_param: - self.project_id, - batch_id_param: - self.uid, - type_param: - "RequeueDataWithLabelAsTemplate" - if set_labels_as_template else "RequeueData" - }, - experimental=True) + }""" + % ( + project_id_param, + batch_id_param, + type_param, + project_id_param, + batch_id_param, + type_param, + ), + { + project_id_param: self.project_id, + batch_id_param: self.uid, + type_param: "RequeueDataWithLabelAsTemplate" + if set_labels_as_template + else "RequeueData", + }, + experimental=True, + ) return res # modify this function to return an empty list if there are no failed data rows diff --git a/libs/labelbox/src/labelbox/schema/benchmark.py b/libs/labelbox/src/labelbox/schema/benchmark.py index 69cfc2f7f..586530e3c 100644 --- a/libs/labelbox/src/labelbox/schema/benchmark.py +++ b/libs/labelbox/src/labelbox/schema/benchmark.py @@ -3,7 +3,7 @@ class Benchmark(DbObject): - """ Represents a benchmark label. + """Represents a benchmark label. The Benchmarks tool works by interspersing data to be labeled, for which there is a benchmark label, to each person labeling. These @@ -19,6 +19,7 @@ class Benchmark(DbObject): created_by (Relationship): `ToOne` relationship to User reference_label (Relationship): `ToOne` relationship to Label """ + created_at = Field.DateTime("created_at") created_by = Relationship.ToOne("User", False, "created_by") last_activity = Field.DateTime("last_activity") @@ -30,7 +31,10 @@ class Benchmark(DbObject): def delete(self) -> None: label_param = "labelId" query_str = """mutation DeleteBenchmarkPyApi($%s: ID!) 
{ - deleteBenchmark(where: {labelId: $%s}) {id}} """ % (label_param, - label_param) - self.client.execute(query_str, - {label_param: self.reference_label().uid}) + deleteBenchmark(where: {labelId: $%s}) {id}} """ % ( + label_param, + label_param, + ) + self.client.execute( + query_str, {label_param: self.reference_label().uid} + ) diff --git a/libs/labelbox/src/labelbox/schema/bulk_import_request.py b/libs/labelbox/src/labelbox/schema/bulk_import_request.py index 6e65aab58..7caa2c6eb 100644 --- a/libs/labelbox/src/labelbox/schema/bulk_import_request.py +++ b/libs/labelbox/src/labelbox/schema/bulk_import_request.py @@ -8,10 +8,29 @@ from google.api_core import retry from labelbox import parser import requests -from pydantic import ValidationError, BaseModel, Field, field_validator, model_validator, ConfigDict, StringConstraints +from pydantic import ( + ValidationError, + BaseModel, + Field, + field_validator, + model_validator, + ConfigDict, + StringConstraints, +) from typing_extensions import Literal, Annotated -from typing import (Any, List, Optional, BinaryIO, Dict, Iterable, Tuple, Union, - Type, Set, TYPE_CHECKING) +from typing import ( + Any, + List, + Optional, + BinaryIO, + Dict, + Iterable, + Tuple, + Union, + Type, + Set, + TYPE_CHECKING, +) from labelbox import exceptions as lb_exceptions from labelbox import utils @@ -29,11 +48,13 @@ NDJSON_MIME_TYPE = "application/x-ndjson" logger = logging.getLogger(__name__) -#TODO: Deprecate this library in place of labelimport and malprediction import library. +# TODO: Deprecate this library in place of labelimport and malprediction import library. + def _determinants(parent_cls: Any) -> List[str]: return [ - k for k, v in parent_cls.model_fields.items() + k + for k, v in parent_cls.model_fields.items() if v.json_schema_extra and "determinant" in v.json_schema_extra ] @@ -43,8 +64,9 @@ def _make_file_name(project_id: str, name: str) -> str: # TODO(gszpak): move it to client.py -def _make_request_data(project_id: str, name: str, content_length: int, - file_name: str) -> dict: +def _make_request_data( + project_id: str, name: str, content_length: int, file_name: str +) -> dict: query_str = """mutation createBulkImportRequestFromFilePyApi( $projectId: ID!, $name: String!, $file: Upload!, $contentLength: Int!) 
{ createBulkImportRequest(data: { @@ -63,26 +85,30 @@ def _make_request_data(project_id: str, name: str, content_length: int, "projectId": project_id, "name": name, "file": None, - "contentLength": content_length + "contentLength": content_length, } operations = json.dumps({"variables": variables, "query": query_str}) return { "operations": operations, - "map": (None, json.dumps({file_name: ["variables.file"]})) + "map": (None, json.dumps({file_name: ["variables.file"]})), } def _send_create_file_command( - client, request_data: dict, file_name: str, - file_data: Tuple[str, Union[bytes, BinaryIO], str]) -> dict: - + client, + request_data: dict, + file_name: str, + file_data: Tuple[str, Union[bytes, BinaryIO], str], +) -> dict: response = client.execute(data=request_data, files={file_name: file_data}) if not response.get("createBulkImportRequest", None): raise lb_exceptions.LabelboxError( - "Failed to create BulkImportRequest, message: %s" % - response.get("errors", None) or response.get("error", None)) + "Failed to create BulkImportRequest, message: %s" + % response.get("errors", None) + or response.get("error", None) + ) return response @@ -101,6 +127,7 @@ class BulkImportRequest(DbObject): project (Relationship): `ToOne` relationship to Project created_by (Relationship): `ToOne` relationship to User """ + name = lb_Field.String("name") state = lb_Field.Enum(BulkImportRequestState, "state") input_file_url = lb_Field.String("input_file_url") @@ -182,8 +209,7 @@ def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]: return parser.loads(response.text) def refresh(self) -> None: - """Synchronizes values of all fields with the database. - """ + """Synchronizes values of all fields with the database.""" query_str, params = query.get_single(BulkImportRequest, self.uid) res = self.client.execute(query_str, params) res = res[utils.camel_case(BulkImportRequest.type_name())] @@ -207,16 +233,21 @@ def wait_until_done(self, sleep_time_seconds: int = 5) -> None: time.sleep(sleep_time_seconds) self.__exponential_backoff_refresh() - @retry.Retry(predicate=retry.if_exception_type(lb_exceptions.ApiLimitError, - lb_exceptions.TimeoutError, - lb_exceptions.NetworkError)) + @retry.Retry( + predicate=retry.if_exception_type( + lb_exceptions.ApiLimitError, + lb_exceptions.TimeoutError, + lb_exceptions.NetworkError, + ) + ) def __exponential_backoff_refresh(self) -> None: self.refresh() @classmethod - def from_name(cls, client, project_id: str, - name: str) -> 'BulkImportRequest': - """ Fetches existing BulkImportRequest. + def from_name( + cls, client, project_id: str, name: str + ) -> "BulkImportRequest": + """Fetches existing BulkImportRequest. Args: client (Client): a Labelbox client @@ -238,15 +269,12 @@ def from_name(cls, client, project_id: str, """ % query.results_query_part(cls) params = {"projectId": project_id, "name": name} response = client.execute(query_str, params=params) - return cls(client, response['bulkImportRequest']) + return cls(client, response["bulkImportRequest"]) @classmethod - def create_from_url(cls, - client, - project_id: str, - name: str, - url: str, - validate=True) -> 'BulkImportRequest': + def create_from_url( + cls, client, project_id: str, name: str, url: str, validate=True + ) -> "BulkImportRequest": """ Creates a BulkImportRequest from a publicly accessible URL to an ndjson file with predictions. 
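
Although the module carries a TODO to deprecate it in favour of the label import and MAL prediction import classes, the URL-based constructor reformatted below is still simple to exercise. A sketch with placeholder values; the URL must be publicly reachable, and wait_until_done() polls the server until the import job finishes:

    from labelbox.schema.bulk_import_request import BulkImportRequest

    bir = BulkImportRequest.create_from_url(
        client,
        project_id="<PROJECT_ID>",
        name="bulk-import",
        url="https://example.com/predictions.ndjson",  # assumed public NDJSON file
    )
    bir.wait_until_done()  # refreshes with backoff until the job is no longer running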
@@ -282,17 +310,19 @@ def create_from_url(cls, """ % query.results_query_part(cls) params = {"projectId": project_id, "name": name, "fileUrl": url} bulk_import_request_response = client.execute(query_str, params=params) - return cls(client, - bulk_import_request_response["createBulkImportRequest"]) + return cls( + client, bulk_import_request_response["createBulkImportRequest"] + ) @classmethod - def create_from_objects(cls, - client, - project_id: str, - name: str, - predictions: Union[Iterable[Dict], - Iterable["Label"]], - validate=True) -> 'BulkImportRequest': + def create_from_objects( + cls, + client, + project_id: str, + name: str, + predictions: Union[Iterable[Dict], Iterable["Label"]], + validate=True, + ) -> "BulkImportRequest": """ Creates a `BulkImportRequest` from an iterable of dictionaries. @@ -332,27 +362,27 @@ def create_from_objects(cls, data_str = parser.dumps(ndjson_predictions) if not data_str: - raise ValueError('annotations cannot be empty') + raise ValueError("annotations cannot be empty") - data = data_str.encode('utf-8') + data = data_str.encode("utf-8") file_name = _make_file_name(project_id, name) - request_data = _make_request_data(project_id, name, len(data_str), - file_name) + request_data = _make_request_data( + project_id, name, len(data_str), file_name + ) file_data = (file_name, data, NDJSON_MIME_TYPE) - response_data = _send_create_file_command(client, - request_data=request_data, - file_name=file_name, - file_data=file_data) + response_data = _send_create_file_command( + client, + request_data=request_data, + file_name=file_name, + file_data=file_data, + ) return cls(client, response_data["createBulkImportRequest"]) @classmethod - def create_from_local_file(cls, - client, - project_id: str, - name: str, - file: Path, - validate_file=True) -> 'BulkImportRequest': + def create_from_local_file( + cls, client, project_id: str, name: str, file: Path, validate_file=True + ) -> "BulkImportRequest": """ Creates a BulkImportRequest from a local ndjson file with predictions. @@ -369,10 +399,11 @@ def create_from_local_file(cls, """ file_name = _make_file_name(project_id, name) content_length = file.stat().st_size - request_data = _make_request_data(project_id, name, content_length, - file_name) + request_data = _make_request_data( + project_id, name, content_length, file_name + ) - with file.open('rb') as f: + with file.open("rb") as f: if validate_file: reader = parser.reader(f) # ensure that the underlying json load call is valid @@ -386,12 +417,13 @@ def create_from_local_file(cls, else: f.seek(0) file_data = (file.name, f, NDJSON_MIME_TYPE) - response_data = _send_create_file_command(client, request_data, - file_name, file_data) + response_data = _send_create_file_command( + client, request_data, file_name, file_data + ) return cls(client, response_data["createBulkImportRequest"]) def delete(self) -> None: - """ Deletes the import job and also any annotations created by this import. + """Deletes the import job and also any annotations created by this import. Returns: None @@ -406,8 +438,9 @@ def delete(self) -> None: self.client.execute(query_str, {id_param: self.uid}) -def _validate_ndjson(lines: Iterable[Dict[str, Any]], - project: "Project") -> None: +def _validate_ndjson( + lines: Iterable[Dict[str, Any]], project: "Project" +) -> None: """ Client side validation of an ndjson object. 
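
The client-side checks in _validate_ndjson below (unique UUIDs per job, a dataRow reference, and feature schemas that exist in the project ontology) operate on plain dicts. A hedged example of one NDJSON line of roughly the shape they accept; every ID is a placeholder and the bbox key layout is an assumption, not something taken from this patch:

    # One NDJSON-style annotation line; all identifiers are illustrative.
    ndjson_line = {
        "uuid": "c2b2f3a0-1111-4c8e-9d55-3f1aa0e2b001",  # must be unique within the job
        "schemaId": "<BBOX_FEATURE_SCHEMA_ID>",          # or use "name": "<tool name>"
        "dataRow": {"id": "<DATA_ROW_ID>"},
        "bbox": {"top": 10.0, "left": 20.0, "height": 30.0, "width": 40.0},
    }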
@@ -426,26 +459,29 @@ def _validate_ndjson(lines: Iterable[Dict[str, Any]], UuidError: Duplicate UUID in upload """ feature_schemas_by_id, feature_schemas_by_name = get_mal_schemas( - project.ontology()) + project.ontology() + ) uids: Set[str] = set() for idx, line in enumerate(lines): try: annotation = NDAnnotation(**line) - annotation.validate_instance(feature_schemas_by_id, - feature_schemas_by_name) + annotation.validate_instance( + feature_schemas_by_id, feature_schemas_by_name + ) uuid = str(annotation.uuid) if uuid in uids: raise lb_exceptions.UuidError( - f'{uuid} already used in this import job, ' - 'must be unique for the project.') + f"{uuid} already used in this import job, " + "must be unique for the project." + ) uids.add(uuid) - except (ValidationError, ValueError, TypeError, - KeyError) as e: + except (ValidationError, ValueError, TypeError, KeyError) as e: raise lb_exceptions.MALValidationError( - f"Invalid NDJson on line {idx}") from e + f"Invalid NDJson on line {idx}" + ) from e -#The rest of this file contains objects for MAL validation +# The rest of this file contains objects for MAL validation def parse_classification(tool): """ Parses a classification from an ontology. Only radio, checklist, and text are supported for mal @@ -456,20 +492,20 @@ def parse_classification(tool): Returns: dict """ - if tool['type'] in ['radio', 'checklist']: - option_schema_ids = [r['featureSchemaId'] for r in tool['options']] - option_names = [r['value'] for r in tool['options']] + if tool["type"] in ["radio", "checklist"]: + option_schema_ids = [r["featureSchemaId"] for r in tool["options"]] + option_names = [r["value"] for r in tool["options"]] return { - 'tool': tool['type'], - 'featureSchemaId': tool['featureSchemaId'], - 'name': tool['name'], - 'options': [*option_schema_ids, *option_names] + "tool": tool["type"], + "featureSchemaId": tool["featureSchemaId"], + "name": tool["name"], + "options": [*option_schema_ids, *option_names], } - elif tool['type'] == 'text': + elif tool["type"] == "text": return { - 'tool': tool['type'], - 'name': tool['name'], - 'featureSchemaId': tool['featureSchemaId'] + "tool": tool["type"], + "name": tool["name"], + "featureSchemaId": tool["featureSchemaId"], } @@ -485,31 +521,32 @@ def get_mal_schemas(ontology): valid_feature_schemas_by_schema_id = {} valid_feature_schemas_by_name = {} - for tool in ontology.normalized['tools']: + for tool in ontology.normalized["tools"]: classifications = [ parse_classification(classification_tool) - for classification_tool in tool['classifications'] + for classification_tool in tool["classifications"] ] classifications_by_schema_id = { - v['featureSchemaId']: v for v in classifications + v["featureSchemaId"]: v for v in classifications } - classifications_by_name = {v['name']: v for v in classifications} - valid_feature_schemas_by_schema_id[tool['featureSchemaId']] = { - 'tool': tool['tool'], - 'classificationsBySchemaId': classifications_by_schema_id, - 'classificationsByName': classifications_by_name, - 'name': tool['name'] + classifications_by_name = {v["name"]: v for v in classifications} + valid_feature_schemas_by_schema_id[tool["featureSchemaId"]] = { + "tool": tool["tool"], + "classificationsBySchemaId": classifications_by_schema_id, + "classificationsByName": classifications_by_name, + "name": tool["name"], } - valid_feature_schemas_by_name[tool['name']] = { - 'tool': tool['tool'], - 'classificationsBySchemaId': classifications_by_schema_id, - 'classificationsByName': classifications_by_name, - 'name': 
tool['name'] + valid_feature_schemas_by_name[tool["name"]] = { + "tool": tool["tool"], + "classificationsBySchemaId": classifications_by_schema_id, + "classificationsByName": classifications_by_name, + "name": tool["name"], } - for tool in ontology.normalized['classifications']: - valid_feature_schemas_by_schema_id[ - tool['featureSchemaId']] = parse_classification(tool) - valid_feature_schemas_by_name[tool['name']] = parse_classification(tool) + for tool in ontology.normalized["classifications"]: + valid_feature_schemas_by_schema_id[tool["featureSchemaId"]] = ( + parse_classification(tool) + ) + valid_feature_schemas_by_name[tool["name"]] = parse_classification(tool) return valid_feature_schemas_by_schema_id, valid_feature_schemas_by_name @@ -531,13 +568,12 @@ class FrameLocation(BaseModel): class VideoSupported(BaseModel): - #Note that frames are only allowed as top level inferences for video + # Note that frames are only allowed as top level inferences for video frames: Optional[List[FrameLocation]] = None # Base class for a special kind of union. class SpecialUnion: - def __new__(cls, **kwargs): return cls.build(kwargs) @@ -553,7 +589,8 @@ def get_union_types(cls): union_types = [x for x in cls.__orig_bases__ if hasattr(x, "__args__")] if len(union_types) < 1: raise TypeError( - "Class {cls} should inherit from a union of objects to build") + "Class {cls} should inherit from a union of objects to build" + ) if len(union_types) > 1: raise TypeError( f"Class {cls} should inherit from exactly one union of objects to build. Found {union_types}" @@ -561,15 +598,14 @@ def get_union_types(cls): return union_types[0].__args__[0].__args__ @classmethod - def build(cls: Any, data: Union[dict, - BaseModel]) -> "NDBase": + def build(cls: Any, data: Union[dict, BaseModel]) -> "NDBase": """ - Checks through all objects in the union to see which matches the input data. - Args: - data (Union[dict, BaseModel]) : The data for constructing one of the objects in the union - raises: - KeyError: data does not contain the determinant fields for any of the types supported by this SpecialUnion - ValidationError: Error while trying to construct a specific object in the union + Checks through all objects in the union to see which matches the input data. 
+ Args: + data (Union[dict, BaseModel]) : The data for constructing one of the objects in the union + raises: + KeyError: data does not contain the determinant fields for any of the types supported by this SpecialUnion + ValidationError: Error while trying to construct a specific object in the union """ if isinstance(data, BaseModel): @@ -588,11 +624,11 @@ def build(cls: Any, data: Union[dict, matched = type_ if matched is not None: - #These two have the exact same top level keys + # These two have the exact same top level keys if matched in [NDRadio, NDText]: - if isinstance(data['answer'], dict): + if isinstance(data["answer"], dict): matched = NDRadio - elif isinstance(data['answer'], str): + elif isinstance(data["answer"], str): matched = NDText else: raise TypeError( @@ -606,10 +642,10 @@ def build(cls: Any, data: Union[dict, @classmethod def schema(cls): - results = {'definitions': {}} + results = {"definitions": {}} for cl in cls.get_union_types(): schema = cl.schema() - results['definitions'].update(schema.pop('definitions')) + results["definitions"].update(schema.pop("definitions")) results[cl.__name__] = schema return results @@ -626,7 +662,8 @@ class NDFeatureSchema(BaseModel): def most_set_one(self): if self.schemaId is None and self.name is None: raise ValueError( - "Must set either schemaId or name for all feature schemas") + "Must set either schemaId or name for all feature schemas" + ) return self @@ -636,16 +673,19 @@ class NDBase(NDFeatureSchema): dataRow: DataRow model_config = ConfigDict(extra="forbid") - def validate_feature_schemas(self, valid_feature_schemas_by_id, - valid_feature_schemas_by_name): + def validate_feature_schemas( + self, valid_feature_schemas_by_id, valid_feature_schemas_by_name + ): if self.name: if self.name not in valid_feature_schemas_by_name: raise ValueError( f"Name {self.name} is not valid for the provided project's ontology." ) - if self.ontology_type != valid_feature_schemas_by_name[ - self.name]['tool']: + if ( + self.ontology_type + != valid_feature_schemas_by_name[self.name]["tool"] + ): raise ValueError( f"Name {self.name} does not map to the assigned tool {valid_feature_schemas_by_name[self.name]['tool']}" ) @@ -656,16 +696,20 @@ def validate_feature_schemas(self, valid_feature_schemas_by_id, f"Schema id {self.schemaId} is not valid for the provided project's ontology." 
) - if self.ontology_type != valid_feature_schemas_by_id[ - self.schemaId]['tool']: + if ( + self.ontology_type + != valid_feature_schemas_by_id[self.schemaId]["tool"] + ): raise ValueError( f"Schema id {self.schemaId} does not map to the assigned tool {valid_feature_schemas_by_id[self.schemaId]['tool']}" ) - def validate_instance(self, valid_feature_schemas_by_id, - valid_feature_schemas_by_name): - self.validate_feature_schemas(valid_feature_schemas_by_id, - valid_feature_schemas_by_name) + def validate_instance( + self, valid_feature_schemas_by_id, valid_feature_schemas_by_name + ): + self.validate_feature_schemas( + valid_feature_schemas_by_id, valid_feature_schemas_by_name + ) ###### Classifications ###### @@ -674,36 +718,42 @@ def validate_instance(self, valid_feature_schemas_by_id, class NDText(NDBase): ontology_type: Literal["text"] = "text" answer: str = Field(json_schema_extra={"determinant": True}) - #No feature schema to check + # No feature schema to check class NDChecklist(VideoSupported, NDBase): ontology_type: Literal["checklist"] = "checklist" - answers: List[NDFeatureSchema] = Field(json_schema_extra={"determinant": True}) + answers: List[NDFeatureSchema] = Field( + json_schema_extra={"determinant": True} + ) - @field_validator('answers', mode="before") + @field_validator("answers", mode="before") def validate_answers(cls, value, field): - #constr not working with mypy. + # constr not working with mypy. if not len(value): raise ValueError("Checklist answers should not be empty") return value - def validate_feature_schemas(self, valid_feature_schemas_by_id, - valid_feature_schemas_by_name): - #Test top level feature schema for this tool - super(NDChecklist, - self).validate_feature_schemas(valid_feature_schemas_by_id, - valid_feature_schemas_by_name) - #Test the feature schemas provided to the answer field - if len(set([answer.name or answer.schemaId for answer in self.answers - ])) != len(self.answers): + def validate_feature_schemas( + self, valid_feature_schemas_by_id, valid_feature_schemas_by_name + ): + # Test top level feature schema for this tool + super(NDChecklist, self).validate_feature_schemas( + valid_feature_schemas_by_id, valid_feature_schemas_by_name + ) + # Test the feature schemas provided to the answer field + if len( + set([answer.name or answer.schemaId for answer in self.answers]) + ) != len(self.answers): raise ValueError( - f"Duplicated featureSchema found for checklist {self.uuid}") + f"Duplicated featureSchema found for checklist {self.uuid}" + ) for answer in self.answers: - options = valid_feature_schemas_by_name[ - self. - name]['options'] if self.name else valid_feature_schemas_by_id[ - self.schemaId]['options'] + options = ( + valid_feature_schemas_by_name[self.name]["options"] + if self.name + else valid_feature_schemas_by_id[self.schemaId]["options"] + ) if answer.name not in options and answer.schemaId not in options: raise ValueError( f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. 
Found {answer}" @@ -714,26 +764,35 @@ class NDRadio(VideoSupported, NDBase): ontology_type: Literal["radio"] = "radio" answer: NDFeatureSchema = Field(json_schema_extra={"determinant": True}) - def validate_feature_schemas(self, valid_feature_schemas_by_id, - valid_feature_schemas_by_name): - super(NDRadio, - self).validate_feature_schemas(valid_feature_schemas_by_id, - valid_feature_schemas_by_name) - options = valid_feature_schemas_by_name[ - self.name]['options'] if self.name else valid_feature_schemas_by_id[ - self.schemaId]['options'] - if self.answer.name not in options and self.answer.schemaId not in options: + def validate_feature_schemas( + self, valid_feature_schemas_by_id, valid_feature_schemas_by_name + ): + super(NDRadio, self).validate_feature_schemas( + valid_feature_schemas_by_id, valid_feature_schemas_by_name + ) + options = ( + valid_feature_schemas_by_name[self.name]["options"] + if self.name + else valid_feature_schemas_by_id[self.schemaId]["options"] + ) + if ( + self.answer.name not in options + and self.answer.schemaId not in options + ): raise ValueError( f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found {self.answer.name or self.answer.schemaId}" ) -#A union with custom construction logic to improve error messages +# A union with custom construction logic to improve error messages class NDClassification( - SpecialUnion, - Type[Union[ # type: ignore - NDText, NDRadio, NDChecklist]]): - ... + SpecialUnion, + Type[ + Union[ # type: ignore + NDText, NDRadio, NDChecklist + ] + ], +): ... ###### Tools ###### @@ -742,35 +801,41 @@ class NDClassification( class NDBaseTool(NDBase): classifications: List[NDClassification] = [] - #This is indepdent of our problem - def validate_feature_schemas(self, valid_feature_schemas_by_id, - valid_feature_schemas_by_name): - super(NDBaseTool, - self).validate_feature_schemas(valid_feature_schemas_by_id, - valid_feature_schemas_by_name) + # This is indepdent of our problem + def validate_feature_schemas( + self, valid_feature_schemas_by_id, valid_feature_schemas_by_name + ): + super(NDBaseTool, self).validate_feature_schemas( + valid_feature_schemas_by_id, valid_feature_schemas_by_name + ) for classification in self.classifications: classification.validate_feature_schemas( - valid_feature_schemas_by_name[ - self.name]['classificationsBySchemaId'] - if self.name else valid_feature_schemas_by_id[self.schemaId] - ['classificationsBySchemaId'], valid_feature_schemas_by_name[ - self.name]['classificationsByName'] - if self.name else valid_feature_schemas_by_id[ - self.schemaId]['classificationsByName']) - - @field_validator('classifications', mode="before") + valid_feature_schemas_by_name[self.name][ + "classificationsBySchemaId" + ] + if self.name + else valid_feature_schemas_by_id[self.schemaId][ + "classificationsBySchemaId" + ], + valid_feature_schemas_by_name[self.name][ + "classificationsByName" + ] + if self.name + else valid_feature_schemas_by_id[self.schemaId][ + "classificationsByName" + ], + ) + + @field_validator("classifications", mode="before") def validate_subclasses(cls, value, field): - #Create uuid and datarow id so we don't have to define classification objects twice - #This is caused by the fact that we require these ids for top level classifications but not for subclasses + # Create uuid and datarow id so we don't have to define classification objects twice + # This is caused by the fact that we require these ids for top level classifications but not for subclasses results = [] - 
dummy_id = 'child'.center(25, '_') + dummy_id = "child".center(25, "_") for row in value: - results.append({ - **row, 'dataRow': { - 'id': dummy_id - }, - 'uuid': str(uuid4()) - }) + results.append( + {**row, "dataRow": {"id": dummy_id}, "uuid": str(uuid4())} + ) return results @@ -778,11 +843,12 @@ class NDPolygon(NDBaseTool): ontology_type: Literal["polygon"] = "polygon" polygon: List[Point] = Field(json_schema_extra={"determinant": True}) - @field_validator('polygon') + @field_validator("polygon") def is_geom_valid(cls, v): if len(v) < 3: raise ValueError( - f"A polygon must have at least 3 points to be valid. Found {v}") + f"A polygon must have at least 3 points to be valid. Found {v}" + ) return v @@ -790,24 +856,25 @@ class NDPolyline(NDBaseTool): ontology_type: Literal["line"] = "line" line: List[Point] = Field(json_schema_extra={"determinant": True}) - @field_validator('line') + @field_validator("line") def is_geom_valid(cls, v): if len(v) < 2: raise ValueError( - f"A line must have at least 2 points to be valid. Found {v}") + f"A line must have at least 2 points to be valid. Found {v}" + ) return v class NDRectangle(NDBaseTool): ontology_type: Literal["rectangle"] = "rectangle" bbox: Bbox = Field(json_schema_extra={"determinant": True}) - #Could check if points are positive + # Could check if points are positive class NDPoint(NDBaseTool): ontology_type: Literal["point"] = "point" point: Point = Field(json_schema_extra={"determinant": True}) - #Could check if points are positive + # Could check if points are positive class EntityLocation(BaseModel): @@ -819,17 +886,18 @@ class NDTextEntity(NDBaseTool): ontology_type: Literal["named-entity"] = "named-entity" location: EntityLocation = Field(json_schema_extra={"determinant": True}) - @field_validator('location') + @field_validator("location") def is_valid_location(cls, v): if isinstance(v, BaseModel): v = v.model_dump() if len(v) < 2: raise ValueError( - f"A line must have at least 2 points to be valid. Found {v}") - if v['start'] < 0: + f"A line must have at least 2 points to be valid. Found {v}" + ) + if v["start"] < 0: raise ValueError(f"Text location must be positive. Found {v}") - if v['start'] > v['end']: + if v["start"] > v["end"]: raise ValueError( f"Text start location must be less or equal than end. Found {v}" ) @@ -840,7 +908,7 @@ class RLEMaskFeatures(BaseModel): counts: List[int] size: List[int] - @field_validator('counts') + @field_validator("counts") def validate_counts(cls, counts): if not all([count >= 0 for count in counts]): raise ValueError( @@ -848,7 +916,7 @@ def validate_counts(cls, counts): ) return counts - @field_validator('size') + @field_validator("size") def validate_size(cls, size): if len(size) != 2: raise ValueError( @@ -856,7 +924,8 @@ def validate_size(cls, size): ) if not all([count > 0 for count in size]): raise ValueError( - f"Mask `size` should be a postitive int. Found : {size}") + f"Mask `size` should be a postitive int. Found : {size}" + ) return size @@ -869,9 +938,9 @@ class URIMaskFeatures(BaseModel): instanceURI: str colorRGB: Union[List[int], Tuple[int, int, int]] - @field_validator('colorRGB') + @field_validator("colorRGB") def validate_color(cls, colorRGB): - #Does the dtype matter? Can it be a float? + # Does the dtype matter? Can it be a float? if not isinstance(colorRGB, (tuple, list)): raise ValueError( f"Received color that is not a list or tuple. 
Found : {colorRGB}" @@ -882,39 +951,46 @@ def validate_color(cls, colorRGB): ) elif not all([0 <= color <= 255 for color in colorRGB]): raise ValueError( - f"All rgb colors must be between 0 and 255. Found : {colorRGB}") + f"All rgb colors must be between 0 and 255. Found : {colorRGB}" + ) return colorRGB class NDMask(NDBaseTool): ontology_type: Literal["superpixel"] = "superpixel" - mask: Union[URIMaskFeatures, PNGMaskFeatures, - RLEMaskFeatures] = Field(json_schema_extra={"determinant": True}) + mask: Union[URIMaskFeatures, PNGMaskFeatures, RLEMaskFeatures] = Field( + json_schema_extra={"determinant": True} + ) -#A union with custom construction logic to improve error messages +# A union with custom construction logic to improve error messages class NDTool( - SpecialUnion, - Type[Union[ # type: ignore + SpecialUnion, + Type[ + Union[ # type: ignore NDMask, NDTextEntity, NDPoint, NDRectangle, NDPolyline, NDPolygon, - ]]): - ... + ] + ], +): ... class NDAnnotation( - SpecialUnion, - Type[Union[ # type: ignore - NDTool, NDClassification]]): - + SpecialUnion, + Type[ + Union[ # type: ignore + NDTool, NDClassification + ] + ], +): @classmethod def build(cls: Any, data) -> "NDBase": if not isinstance(data, dict): - raise ValueError('value must be dict') + raise ValueError("value must be dict") errors = [] for cl in cls.get_union_types(): try: @@ -922,14 +998,15 @@ def build(cls: Any, data) -> "NDBase": except KeyError as e: errors.append(f"{cl.__name__}: {e}") - raise ValueError('Unable to construct any annotation.\n{}'.format( - "\n".join(errors))) + raise ValueError( + "Unable to construct any annotation.\n{}".format("\n".join(errors)) + ) @classmethod def schema(cls): - data = {'definitions': {}} + data = {"definitions": {}} for type_ in cls.get_union_types(): schema_ = type_.schema() - data['definitions'].update(schema_.pop('definitions')) + data["definitions"].update(schema_.pop("definitions")) data[type_.__name__] = schema_ return data diff --git a/libs/labelbox/src/labelbox/schema/catalog.py b/libs/labelbox/src/labelbox/schema/catalog.py index c377703b1..567bbd777 100644 --- a/libs/labelbox/src/labelbox/schema/catalog.py +++ b/libs/labelbox/src/labelbox/schema/catalog.py @@ -2,12 +2,15 @@ from labelbox.orm.db_object import experimental from labelbox.schema.export_filters import CatalogExportFilters, build_filters -from labelbox.schema.export_params import (CatalogExportParams, - validate_catalog_export_params) +from labelbox.schema.export_params import ( + CatalogExportParams, + validate_catalog_export_params, +) from labelbox.schema.export_task import ExportTask from labelbox.schema.task import Task from typing import TYPE_CHECKING + if TYPE_CHECKING: from labelbox import Client @@ -15,7 +18,7 @@ class Catalog: client: "Client" - def __init__(self, client: 'Client'): + def __init__(self, client: "Client"): self.client = client def export_v2( @@ -43,7 +46,7 @@ def export_v2( >>> task.result """ task, is_streamable = self._export(task_name, filters, params) - if (is_streamable): + if is_streamable: return ExportTask(task, True) return task @@ -72,44 +75,49 @@ def export( task, _ = self._export(task_name, filters, params, streamable=True) return ExportTask(task) - def _export(self, - task_name: Optional[str] = None, - filters: Union[CatalogExportFilters, Dict[str, List[str]], - None] = None, - params: Optional[CatalogExportParams] = None, - streamable: bool = False) -> Tuple[Task, bool]: - - _params = params or CatalogExportParams({ - "attachments": False, - "embeddings": False, - 
"metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "model_run_ids": None, - "project_ids": None, - "interpolated_frames": False, - "all_projects": False, - "all_model_runs": False, - }) + def _export( + self, + task_name: Optional[str] = None, + filters: Union[CatalogExportFilters, Dict[str, List[str]], None] = None, + params: Optional[CatalogExportParams] = None, + streamable: bool = False, + ) -> Tuple[Task, bool]: + _params = params or CatalogExportParams( + { + "attachments": False, + "embeddings": False, + "metadata_fields": False, + "data_row_details": False, + "project_details": False, + "performance_details": False, + "label_details": False, + "media_type_override": None, + "model_run_ids": None, + "project_ids": None, + "interpolated_frames": False, + "all_projects": False, + "all_model_runs": False, + } + ) validate_catalog_export_params(_params) - _filters = filters or CatalogExportFilters({ - "last_activity_at": None, - "label_created_at": None, - "data_row_ids": None, - "global_keys": None, - }) + _filters = filters or CatalogExportFilters( + { + "last_activity_at": None, + "label_created_at": None, + "data_row_ids": None, + "global_keys": None, + } + ) mutation_name = "exportDataRowsInCatalog" create_task_query_str = ( f"mutation {mutation_name}PyApi" f"($input: ExportDataRowsInCatalogInput!)" - f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}") + f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}" + ) - media_type_override = _params.get('media_type_override', None) + media_type_override = _params.get("media_type_override", None) query_params: Dict[str, Any] = { "input": { "taskName": task_name, @@ -121,35 +129,30 @@ def _export(self, }, "isStreamableReady": True, "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeEmbeddings": - _params.get('embeddings', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), - "includePredictions": - _params.get('predictions', False), - "projectIds": - _params.get('project_ids', None), - "modelRunIds": - _params.get('model_run_ids', None), - "allProjects": - _params.get('all_projects', False), - "allModelRuns": - _params.get('all_model_runs', False), + "mediaTypeOverride": media_type_override.value + if media_type_override is not None + else None, + "includeAttachments": _params.get("attachments", False), + "includeEmbeddings": _params.get("embeddings", False), + "includeMetadata": _params.get("metadata_fields", False), + "includeDataRowDetails": _params.get( + "data_row_details", False + ), + "includeProjectDetails": _params.get( + "project_details", False + ), + "includePerformanceDetails": _params.get( + "performance_details", False + ), + "includeLabelDetails": _params.get("label_details", False), + "includeInterpolatedFrames": _params.get( + "interpolated_frames", False + ), + "includePredictions": _params.get("predictions", False), + "projectIds": _params.get("project_ids", None), + "modelRunIds": 
_params.get("model_run_ids", None), + "allProjects": _params.get("all_projects", False), + "allModelRuns": _params.get("all_model_runs", False), }, "streamable": streamable, } @@ -158,9 +161,9 @@ def _export(self, search_query = build_filters(self.client, _filters) query_params["input"]["filters"]["searchQuery"]["query"] = search_query - res = self.client.execute(create_task_query_str, - query_params, - error_log_key="errors") + res = self.client.execute( + create_task_query_str, query_params, error_log_key="errors" + ) res = res[mutation_name] task_id = res["taskId"] is_streamable = res["isStreamable"] diff --git a/libs/labelbox/src/labelbox/schema/confidence_presence_checker.py b/libs/labelbox/src/labelbox/schema/confidence_presence_checker.py index cfdbe0ed3..77d3bfb3f 100644 --- a/libs/labelbox/src/labelbox/schema/confidence_presence_checker.py +++ b/libs/labelbox/src/labelbox/schema/confidence_presence_checker.py @@ -13,8 +13,9 @@ def check(cls, raw_labels: List[Dict[str, Any]]): return len(keys.intersection(set(["confidence"]))) == 1 @classmethod - def _collect_keys_from_list(cls, objects: List[Dict[str, Any]], - keys: Set[str]): + def _collect_keys_from_list( + cls, objects: List[Dict[str, Any]], keys: Set[str] + ): for obj in objects: if isinstance(obj, (list, tuple)): cls._collect_keys_from_list(obj, keys) diff --git a/libs/labelbox/src/labelbox/schema/create_batches_task.py b/libs/labelbox/src/labelbox/schema/create_batches_task.py index eb7b5d150..25ff80917 100644 --- a/libs/labelbox/src/labelbox/schema/create_batches_task.py +++ b/libs/labelbox/src/labelbox/schema/create_batches_task.py @@ -13,9 +13,9 @@ def lru_cache() -> Callable[..., Callable[..., Dict[str, Any]]]: class CreateBatchesTask: - - def __init__(self, client, project_id: str, batch_ids: List[str], - task_ids: List[str]): + def __init__( + self, client, project_id: str, batch_ids: List[str], task_ids: List[str] + ): self.client = client self.project_id = project_id self.batches = batch_ids diff --git a/libs/labelbox/src/labelbox/schema/data_row.py b/libs/labelbox/src/labelbox/schema/data_row.py index b7c9b324d..8987a00f0 100644 --- a/libs/labelbox/src/labelbox/schema/data_row.py +++ b/libs/labelbox/src/labelbox/schema/data_row.py @@ -4,12 +4,24 @@ import json from labelbox.orm import query -from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable, experimental +from labelbox.orm.db_object import ( + DbObject, + Updateable, + BulkDeletable, + experimental, +) from labelbox.orm.model import Entity, Field, Relationship from labelbox.schema.asset_attachment import AttachmentType from labelbox.schema.data_row_metadata import DataRowMetadataField # type: ignore -from labelbox.schema.export_filters import DatarowExportFilters, build_filters, validate_at_least_one_of_data_row_ids_or_global_keys -from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params +from labelbox.schema.export_filters import ( + DatarowExportFilters, + build_filters, + validate_at_least_one_of_data_row_ids_or_global_keys, +) +from labelbox.schema.export_params import ( + CatalogExportParams, + validate_catalog_export_params, +) from labelbox.schema.export_task import ExportTask from labelbox.schema.task import Task @@ -20,16 +32,16 @@ class KeyType(str, Enum): - ID = 'ID' + ID = "ID" """An existing CUID""" - GKEY = 'GKEY' + GKEY = "GKEY" """A Global key, could be existing or non-existing""" - AUTO = 'AUTO' + AUTO = "AUTO" """The key will be auto-generated. 
Only usable for creates""" class DataRow(DbObject, Updateable, BulkDeletable): - """ Internal Labelbox representation of a single piece of data (e.g. image, video, text). + """Internal Labelbox representation of a single piece of data (e.g. image, video, text). Attributes: external_id (str): User-generated file name or identifier @@ -49,6 +61,7 @@ class DataRow(DbObject, Updateable, BulkDeletable): labels (Relationship): `ToMany` relationship to Label attachments (Relationship) `ToMany` relationship with AssetAttachment """ + external_id = Field.String("external_id") global_key = Field.String("global_key") row_data = Field.String("row_data") @@ -59,11 +72,14 @@ class DataRow(DbObject, Updateable, BulkDeletable): dict, graphql_type="DataRowCustomMetadataUpsertInput!", name="metadata_fields", - result_subquery="metadataFields { schemaId name value kind }") - metadata = Field.List(DataRowMetadataField, - name="metadata", - graphql_name="customMetadata", - result_subquery="customMetadata { schemaId value }") + result_subquery="metadataFields { schemaId name value kind }", + ) + metadata = Field.List( + DataRowMetadataField, + name="metadata", + graphql_name="customMetadata", + result_subquery="customMetadata { schemaId value }", + ) # Relationships dataset = Relationship.ToOne("Dataset") @@ -73,7 +89,8 @@ class DataRow(DbObject, Updateable, BulkDeletable): attachments = Relationship.ToMany("AssetAttachment", False, "attachments") supported_meta_types = supported_attachment_types = set( - AttachmentType.__members__) + AttachmentType.__members__ + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -95,12 +112,12 @@ def update(self, **kwargs): row_data = kwargs.get("row_data") if isinstance(row_data, dict): - kwargs['row_data'] = json.dumps(row_data) + kwargs["row_data"] = json.dumps(row_data) super().update(**kwargs) @staticmethod def bulk_delete(data_rows) -> None: - """ Deletes all the given DataRows. + """Deletes all the given DataRows. Args: data_rows (list of DataRow): The DataRows to delete. @@ -108,7 +125,7 @@ def bulk_delete(data_rows) -> None: BulkDeletable._bulk_delete(data_rows, True) def get_winning_label_id(self, project_id: str) -> Optional[str]: - """ Retrieves the winning label ID, i.e. the one that was marked as the + """Retrieves the winning label ID, i.e. the one that was marked as the best for a particular data row, in a project's workflow. Args: @@ -121,21 +138,27 @@ def get_winning_label_id(self, project_id: str) -> Optional[str]: labelingActivity(where: { projectId: $%s }) { selectedLabelId } - }} """ % (data_row_id_param, project_id_param, data_row_id_param, - project_id_param) + }} """ % ( + data_row_id_param, + project_id_param, + data_row_id_param, + project_id_param, + ) - res = self.client.execute(query_str, { - data_row_id_param: self.uid, - project_id_param: project_id, - }) + res = self.client.execute( + query_str, + { + data_row_id_param: self.uid, + project_id_param: project_id, + }, + ) return res["dataRow"]["labelingActivity"]["selectedLabelId"] - def create_attachment(self, - attachment_type, - attachment_value, - attachment_name=None) -> "AssetAttachment": - """ Adds an AssetAttachment to a DataRow. + def create_attachment( + self, attachment_type, attachment_value, attachment_name=None + ) -> "AssetAttachment": + """Adds an AssetAttachment to a DataRow. Labelers can view these attachments while labeling. 
>>> datarow.create_attachment("TEXT", "This is a text message") @@ -151,10 +174,9 @@ def create_attachment(self, ValueError: attachment_type must be one of the supported types. ValueError: attachment_value must be a non-empty string. """ - Entity.AssetAttachment.validate_attachment_json({ - 'type': attachment_type, - 'value': attachment_value - }) + Entity.AssetAttachment.validate_attachment_json( + {"type": attachment_type, "value": attachment_value} + ) attachment_type_param = "type" attachment_value_param = "value" @@ -165,20 +187,29 @@ def create_attachment(self, $%s: AttachmentType!, $%s: String!, $%s: String, $%s: ID!) { createDataRowAttachment(data: { type: $%s value: $%s name: $%s dataRowId: $%s}) {%s}} """ % ( - attachment_type_param, attachment_value_param, - attachment_name_param, data_row_id_param, attachment_type_param, - attachment_value_param, attachment_name_param, data_row_id_param, - query.results_query_part(Entity.AssetAttachment)) + attachment_type_param, + attachment_value_param, + attachment_name_param, + data_row_id_param, + attachment_type_param, + attachment_value_param, + attachment_name_param, + data_row_id_param, + query.results_query_part(Entity.AssetAttachment), + ) res = self.client.execute( - query_str, { + query_str, + { attachment_type_param: attachment_type, attachment_value_param: attachment_value, attachment_name_param: attachment_name, - data_row_id_param: self.uid - }) - return Entity.AssetAttachment(self.client, - res["createDataRowAttachment"]) + data_row_id_param: self.uid, + }, + ) + return Entity.AssetAttachment( + self.client, res["createDataRowAttachment"] + ) @staticmethod def export( @@ -210,12 +241,9 @@ def export( >>> task.wait_till_done() >>> task.result """ - task, _ = DataRow._export(client, - data_rows, - global_keys, - task_name, - params, - streamable=True) + task, _ = DataRow._export( + client, data_rows, global_keys, task_name, params, streamable=True + ) return ExportTask(task) @staticmethod @@ -249,8 +277,9 @@ def export_v2( >>> task.wait_till_done() >>> task.result """ - task, is_streamable = DataRow._export(client, data_rows, global_keys, - task_name, params) + task, is_streamable = DataRow._export( + client, data_rows, global_keys, task_name, params + ) if is_streamable: return ExportTask(task, True) return task @@ -264,21 +293,23 @@ def _export( params: Optional[CatalogExportParams] = None, streamable: bool = False, ) -> Tuple[Task, bool]: - _params = params or CatalogExportParams({ - "attachments": False, - "embeddings": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "model_run_ids": None, - "project_ids": None, - "interpolated_frames": False, - "all_projects": False, - "all_model_runs": False, - }) + _params = params or CatalogExportParams( + { + "attachments": False, + "embeddings": False, + "metadata_fields": False, + "data_row_details": False, + "project_details": False, + "performance_details": False, + "label_details": False, + "media_type_override": None, + "model_run_ids": None, + "project_ids": None, + "interpolated_frames": False, + "all_projects": False, + "all_model_runs": False, + } + ) validate_catalog_export_params(_params) @@ -286,7 +317,8 @@ def _export( create_task_query_str = ( f"mutation {mutation_name}PyApi" f"($input: ExportDataRowsInCatalogInput!)" - f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}") + f"{{{mutation_name}(input: $input){{taskId 
isStreamable}}}}" + ) data_row_ids = [] if data_rows is not None: @@ -296,17 +328,25 @@ def _export( elif isinstance(dr, str): data_row_ids.append(dr) - filters = DatarowExportFilters({ - "data_row_ids": data_row_ids, - "global_keys": None, - }) if data_row_ids else DatarowExportFilters({ - "data_row_ids": None, - "global_keys": global_keys, - }) + filters = ( + DatarowExportFilters( + { + "data_row_ids": data_row_ids, + "global_keys": None, + } + ) + if data_row_ids + else DatarowExportFilters( + { + "data_row_ids": None, + "global_keys": global_keys, + } + ) + ) validate_at_least_one_of_data_row_ids_or_global_keys(filters) search_query = build_filters(client, filters) - media_type_override = _params.get('media_type_override', None) + media_type_override = _params.get("media_type_override", None) if task_name is None: task_name = f"Export v2: data rows {len(data_row_ids)}" @@ -314,48 +354,41 @@ def _export( "input": { "taskName": task_name, "filters": { - "searchQuery": { - "scope": None, - "query": search_query - } + "searchQuery": {"scope": None, "query": search_query} }, "isStreamableReady": True, "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeEmbeddings": - _params.get('embeddings', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), - "projectIds": - _params.get('project_ids', None), - "modelRunIds": - _params.get('model_run_ids', None), - "allProjects": - _params.get('all_projects', False), - "allModelRuns": - _params.get('all_model_runs', False), + "mediaTypeOverride": media_type_override.value + if media_type_override is not None + else None, + "includeAttachments": _params.get("attachments", False), + "includeEmbeddings": _params.get("embeddings", False), + "includeMetadata": _params.get("metadata_fields", False), + "includeDataRowDetails": _params.get( + "data_row_details", False + ), + "includeProjectDetails": _params.get( + "project_details", False + ), + "includePerformanceDetails": _params.get( + "performance_details", False + ), + "includeLabelDetails": _params.get("label_details", False), + "includeInterpolatedFrames": _params.get( + "interpolated_frames", False + ), + "projectIds": _params.get("project_ids", None), + "modelRunIds": _params.get("model_run_ids", None), + "allProjects": _params.get("all_projects", False), + "allModelRuns": _params.get("all_model_runs", False), }, - "streamable": streamable + "streamable": streamable, } } - res = client.execute(create_task_query_str, - query_params, - error_log_key="errors") + res = client.execute( + create_task_query_str, query_params, error_log_key="errors" + ) res = res[mutation_name] task_id = res["taskId"] is_streamable = res["isStreamable"] diff --git a/libs/labelbox/src/labelbox/schema/data_row_metadata.py b/libs/labelbox/src/labelbox/schema/data_row_metadata.py index cb02c32f8..288459a89 100644 --- a/libs/labelbox/src/labelbox/schema/data_row_metadata.py +++ b/libs/labelbox/src/labelbox/schema/data_row_metadata.py @@ -5,15 +5,36 @@ from itertools import chain import warnings -from typing import List, Optional, Dict, 
Union, Callable, Type, Any, Generator, overload +from typing import ( + List, + Optional, + Dict, + Union, + Callable, + Type, + Any, + Generator, + overload, +) from typing_extensions import Annotated from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds from labelbox.schema.identifiable import UniqueId, GlobalKey -from pydantic import BaseModel, Field, StringConstraints, conlist, ConfigDict, model_serializer +from pydantic import ( + BaseModel, + Field, + StringConstraints, + conlist, + ConfigDict, + model_serializer, +) from labelbox.schema.ontology import SchemaId -from labelbox.utils import _CamelCaseMixin, format_iso_datetime, format_iso_from_string +from labelbox.utils import ( + _CamelCaseMixin, + format_iso_datetime, + format_iso_from_string, +) class DataRowMetadataKind(Enum): @@ -28,9 +49,7 @@ class DataRowMetadataKind(Enum): # Metadata schema class DataRowMetadataSchema(BaseModel): uid: SchemaId - name: str = Field(strip_whitespace=True, - min_length=1, - max_length=100) + name: str = Field(strip_whitespace=True, min_length=1, max_length=100) reserved: bool kind: DataRowMetadataKind options: Optional[List["DataRowMetadataSchema"]] = None @@ -39,9 +58,7 @@ class DataRowMetadataSchema(BaseModel): DataRowMetadataSchema.model_rebuild() -Embedding: Type[List[float]] = conlist(float, - min_length=128, - max_length=128) +Embedding: Type[List[float]] = conlist(float, min_length=128, max_length=128) String: Type[str] = Field(max_length=4096) @@ -95,49 +112,53 @@ class _UpsertBatchDataRowMetadata(_CamelCaseMixin): class _DeleteBatchDataRowMetadata(_CamelCaseMixin): data_row_identifier: Union[UniqueId, GlobalKey] schema_ids: List[SchemaId] - + model_config = ConfigDict(arbitrary_types_allowed=True) - + @model_serializer(mode="wrap") def model_serializer(self, handler): res = handler(self) - if 'data_row_identifier' in res.keys(): - key = 'data_row_identifier' - id_type_key = 'id_type' + if "data_row_identifier" in res.keys(): + key = "data_row_identifier" + id_type_key = "id_type" else: - key = 'dataRowIdentifier' - id_type_key = 'idType' + key = "dataRowIdentifier" + id_type_key = "idType" data_row_identifier = res.pop(key) res[key] = { "id": data_row_identifier.key, - id_type_key: data_row_identifier.id_type + id_type_key: data_row_identifier.id_type, } return res -_BatchInputs = Union[List[_UpsertBatchDataRowMetadata], - List[_DeleteBatchDataRowMetadata]] +_BatchInputs = Union[ + List[_UpsertBatchDataRowMetadata], List[_DeleteBatchDataRowMetadata] +] _BatchFunction = Callable[[_BatchInputs], List[DataRowMetadataBatchResponse]] class _UpsertCustomMetadataSchemaEnumOptionInput(_CamelCaseMixin): id: Optional[SchemaId] = None - name: Annotated[str, StringConstraints(strip_whitespace=True, - min_length=1, - max_length=100)] + name: Annotated[ + str, + StringConstraints(strip_whitespace=True, min_length=1, max_length=100), + ] kind: str + class _UpsertCustomMetadataSchemaInput(_CamelCaseMixin): id: Optional[SchemaId] = None - name: Annotated[str, StringConstraints(strip_whitespace=True, - min_length=1, - max_length=100)] + name: Annotated[ + str, + StringConstraints(strip_whitespace=True, min_length=1, max_length=100), + ] kind: str options: Optional[List[_UpsertCustomMetadataSchemaEnumOptionInput]] = None class DataRowMetadataOntology: - """ Ontology for data row metadata + """Ontology for data row metadata Metadata provides additional context for a data rows. Metadata is broken into two classes reserved and custom. 
Reserved fields are defined by Labelbox and used for creating @@ -148,7 +169,6 @@ class DataRowMetadataOntology: """ def __init__(self, client): - self._client = client self._batch_size = 50 # used for uploads and deletes @@ -165,24 +185,24 @@ def _build_ontology(self): f for f in self.fields if f.reserved ] self.reserved_by_id = self._make_id_index(self.reserved_fields) - self.reserved_by_name: Dict[str, Union[DataRowMetadataSchema, Dict[ - str, DataRowMetadataSchema]]] = self._make_name_index( - self.reserved_fields) - self.reserved_by_name_normalized: Dict[ - str, DataRowMetadataSchema] = self._make_normalized_name_index( - self.reserved_fields) + self.reserved_by_name: Dict[ + str, Union[DataRowMetadataSchema, Dict[str, DataRowMetadataSchema]] + ] = self._make_name_index(self.reserved_fields) + self.reserved_by_name_normalized: Dict[str, DataRowMetadataSchema] = ( + self._make_normalized_name_index(self.reserved_fields) + ) # custom fields self.custom_fields: List[DataRowMetadataSchema] = [ f for f in self.fields if not f.reserved ] self.custom_by_id = self._make_id_index(self.custom_fields) - self.custom_by_name: Dict[str, Union[DataRowMetadataSchema, Dict[ - str, - DataRowMetadataSchema]]] = self._make_name_index(self.custom_fields) - self.custom_by_name_normalized: Dict[ - str, DataRowMetadataSchema] = self._make_normalized_name_index( - self.custom_fields) + self.custom_by_name: Dict[ + str, Union[DataRowMetadataSchema, Dict[str, DataRowMetadataSchema]] + ] = self._make_name_index(self.custom_fields) + self.custom_by_name_normalized: Dict[str, DataRowMetadataSchema] = ( + self._make_normalized_name_index(self.custom_fields) + ) @staticmethod def _lookup_in_index_by_name(reserved_index, custom_index, name): @@ -197,7 +217,7 @@ def _lookup_in_index_by_name(reserved_index, custom_index, name): def get_by_name( self, name: str ) -> Union[DataRowMetadataSchema, Dict[str, DataRowMetadataSchema]]: - """ Get metadata by name + """Get metadata by name >>> mdo.get_by_name(name) @@ -210,23 +230,27 @@ def get_by_name( Raises: KeyError: When provided name is not presented in neither reserved nor custom metadata list """ - return self._lookup_in_index_by_name(self.reserved_by_name, - self.custom_by_name, name) + return self._lookup_in_index_by_name( + self.reserved_by_name, self.custom_by_name, name + ) def _get_by_name_normalized(self, name: str) -> DataRowMetadataSchema: - """ Get metadata by name. For options, it provides the option schema instead of list of - options + """Get metadata by name. 
For options, it provides the option schema instead of list of + options """ # using `normalized` indices to find options by name as well - return self._lookup_in_index_by_name(self.reserved_by_name_normalized, - self.custom_by_name_normalized, - name) + return self._lookup_in_index_by_name( + self.reserved_by_name_normalized, + self.custom_by_name_normalized, + name, + ) @staticmethod def _make_name_index( - fields: List[DataRowMetadataSchema] - ) -> Dict[str, Union[DataRowMetadataSchema, Dict[str, - DataRowMetadataSchema]]]: + fields: List[DataRowMetadataSchema], + ) -> Dict[ + str, Union[DataRowMetadataSchema, Dict[str, DataRowMetadataSchema]] + ]: index = {} for f in fields: if f.options: @@ -239,7 +263,7 @@ def _make_name_index( @staticmethod def _make_normalized_name_index( - fields: List[DataRowMetadataSchema] + fields: List[DataRowMetadataSchema], ) -> Dict[str, DataRowMetadataSchema]: index = {} for f in fields: @@ -248,7 +272,7 @@ def _make_normalized_name_index( @staticmethod def _make_id_index( - fields: List[DataRowMetadataSchema] + fields: List[DataRowMetadataSchema], ) -> Dict[SchemaId, DataRowMetadataSchema]: index = {} for f in fields: @@ -287,29 +311,26 @@ def _parse_ontology(raw_ontology) -> List[DataRowMetadataSchema]: for option in schema["options"]: option["uid"] = option["id"] options.append( - DataRowMetadataSchema(**{ - **option, - **{ - "parent": schema["uid"] - } - })) + DataRowMetadataSchema( + **{**option, **{"parent": schema["uid"]}} + ) + ) schema["options"] = options fields.append(DataRowMetadataSchema(**schema)) return fields def refresh_ontology(self): - """ Update the `DataRowMetadataOntology` instance with the latest - metadata ontology schemas + """Update the `DataRowMetadataOntology` instance with the latest + metadata ontology schemas """ self._raw_ontology = self._get_ontology() self._build_ontology() - def create_schema(self, - name: str, - kind: DataRowMetadataKind, - options: List[str] = None) -> DataRowMetadataSchema: - """ Create metadata schema + def create_schema( + self, name: str, kind: DataRowMetadataKind, options: List[str] = None + ) -> DataRowMetadataSchema: + """Create metadata schema >>> mdo.create_schema(name, kind, options) @@ -327,8 +348,9 @@ def create_schema(self, if not isinstance(kind, DataRowMetadataKind): raise ValueError(f"kind '{kind}' must be a `DataRowMetadataKind`") - upsert_schema = _UpsertCustomMetadataSchemaInput(name=name, - kind=kind.value) + upsert_schema = _UpsertCustomMetadataSchemaInput( + name=name, kind=kind.value + ) if options: if kind != DataRowMetadataKind.enum: raise ValueError( @@ -336,7 +358,8 @@ def create_schema(self, ) upsert_enum_options = [ _UpsertCustomMetadataSchemaEnumOptionInput( - name=o, kind=DataRowMetadataKind.option.value) + name=o, kind=DataRowMetadataKind.option.value + ) for o in options ] upsert_schema.options = upsert_enum_options @@ -344,7 +367,7 @@ def create_schema(self, return self._upsert_schema(upsert_schema) def update_schema(self, name: str, new_name: str) -> DataRowMetadataSchema: - """ Update metadata schema + """Update metadata schema >>> mdo.update_schema(name, new_name) @@ -359,24 +382,24 @@ def update_schema(self, name: str, new_name: str) -> DataRowMetadataSchema: KeyError: When provided name is not a valid custom metadata """ schema = self._validate_custom_schema_by_name(name) - upsert_schema = _UpsertCustomMetadataSchemaInput(id=schema.uid, - name=new_name, - kind=schema.kind.value) + upsert_schema = _UpsertCustomMetadataSchemaInput( + id=schema.uid, name=new_name, 
kind=schema.kind.value + ) if schema.options: upsert_enum_options = [ _UpsertCustomMetadataSchemaEnumOptionInput( - id=o.uid, - name=o.name, - kind=DataRowMetadataKind.option.value) + id=o.uid, name=o.name, kind=DataRowMetadataKind.option.value + ) for o in schema.options ] upsert_schema.options = upsert_enum_options return self._upsert_schema(upsert_schema) - def update_enum_option(self, name: str, option: str, - new_option: str) -> DataRowMetadataSchema: - """ Update Enum metadata schema option + def update_enum_option( + self, name: str, option: str, new_option: str + ) -> DataRowMetadataSchema: + """Update Enum metadata schema option >>> mdo.update_enum_option(name, option, new_option) @@ -402,13 +425,14 @@ def update_enum_option(self, name: str, option: str, raise ValueError( f"Enum option '{option}' is not a valid option for Enum '{name}', valid options are: {valid_options}" ) - upsert_schema = _UpsertCustomMetadataSchemaInput(id=schema.uid, - name=schema.name, - kind=schema.kind.value) + upsert_schema = _UpsertCustomMetadataSchemaInput( + id=schema.uid, name=schema.name, kind=schema.kind.value + ) upsert_enum_options = [] for o in schema.options: enum_option = _UpsertCustomMetadataSchemaEnumOptionInput( - id=o.uid, name=o.name, kind=o.kind.value) + id=o.uid, name=o.name, kind=o.kind.value + ) if enum_option.name == option: enum_option.name = new_option upsert_enum_options.append(enum_option) @@ -417,7 +441,7 @@ def update_enum_option(self, name: str, option: str, return self._upsert_schema(upsert_schema) def delete_schema(self, name: str) -> bool: - """ Delete metadata schema + """Delete metadata schema >>> mdo.delete_schema(name) @@ -436,18 +460,17 @@ def delete_schema(self, name: str) -> bool: success } }""" - res = self._client.execute(query, {'where': { - 'id': schema.uid - }})['deleteCustomMetadataSchema'] + res = self._client.execute(query, {"where": {"id": schema.uid}})[ + "deleteCustomMetadataSchema" + ] self.refresh_ontology() - return res['success'] + return res["success"] def parse_metadata( - self, unparsed: List[Dict[str, - List[Union[str, - Dict]]]]) -> List[DataRowMetadata]: - """ Parse metadata responses + self, unparsed: List[Dict[str, List[Union[str, Dict]]]] + ) -> List[DataRowMetadata]: + """Parse metadata responses >>> mdo.parse_metadata([metadata]) @@ -466,15 +489,18 @@ def parse_metadata( if "fields" in dr: fields = self.parse_metadata_fields(dr["fields"]) parsed.append( - DataRowMetadata(data_row_id=dr["dataRowId"], - global_key=dr["globalKey"], - fields=fields)) + DataRowMetadata( + data_row_id=dr["dataRowId"], + global_key=dr["globalKey"], + fields=fields, + ) + ) return parsed def parse_metadata_fields( - self, unparsed: List[Dict[str, - Dict]]) -> List[DataRowMetadataField]: - """ Parse metadata fields as list of `DataRowMetadataField` + self, unparsed: List[Dict[str, Dict]] + ) -> List[DataRowMetadataField]: + """Parse metadata fields as list of `DataRowMetadataField` >>> mdo.parse_metadata_fields([metadata_fields]) @@ -494,31 +520,35 @@ def parse_metadata_fields( self.refresh_ontology() if f["schemaId"] not in self.fields_by_id: raise ValueError( - f"Schema Id `{f['schemaId']}` not found in ontology") + f"Schema Id `{f['schemaId']}` not found in ontology" + ) schema = self.fields_by_id[f["schemaId"]] if schema.kind == DataRowMetadataKind.enum: continue elif schema.kind == DataRowMetadataKind.option: - field = DataRowMetadataField(schema_id=schema.parent, - value=schema.uid) + field = DataRowMetadataField( + schema_id=schema.parent, value=schema.uid + ) 
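+                # Roughly: an exported option entry carries only the option's own
+                # schemaId, so the parsed field is re-pointed at the parent enum
+                # schema and keeps the option uid as its value, e.g.
+                #   {"schemaId": "<option-uid>"}
+                #     -> DataRowMetadataField(schema_id="<enum-uid>", value="<option-uid>")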
elif schema.kind == DataRowMetadataKind.datetime: - field = DataRowMetadataField(schema_id=schema.uid, - value=format_iso_from_string( - f["value"])) + field = DataRowMetadataField( + schema_id=schema.uid, + value=format_iso_from_string(f["value"]), + ) else: - field = DataRowMetadataField(schema_id=schema.uid, - value=f["value"]) + field = DataRowMetadataField( + schema_id=schema.uid, value=f["value"] + ) field.name = schema.name parsed.append(field) return parsed def bulk_upsert( - self, metadata: List[DataRowMetadata] + self, metadata: List[DataRowMetadata] ) -> List[DataRowMetadataBatchResponse]: """Upsert metadata to a list of data rows - + You may specify data row by either data_row_id or global_key >>> metadata = DataRowMetadata( @@ -542,7 +572,7 @@ def bulk_upsert( raise ValueError("Empty list passed") def _batch_upsert( - upserts: List[_UpsertBatchDataRowMetadata] + upserts: List[_UpsertBatchDataRowMetadata], ) -> List[DataRowMetadataBatchResponse]: query = """mutation UpsertDataRowMetadataBetaPyApi($metadata: [DataRowCustomMetadataBatchUpsertInput!]!) { upsertDataRowCustomMetadata(data: $metadata){ @@ -555,14 +585,17 @@ def _batch_upsert( } } }""" - res = self._client.execute( - query, {"metadata": upserts})['upsertDataRowCustomMetadata'] + res = self._client.execute(query, {"metadata": upserts})[ + "upsertDataRowCustomMetadata" + ] return [ - DataRowMetadataBatchResponse(global_key=r['globalKey'], - data_row_id=r['dataRowId'], - error=r['error'], - fields=self.parse_metadata( - [r])[0].fields) for r in res + DataRowMetadataBatchResponse( + global_key=r["globalKey"], + data_row_id=r["dataRowId"], + error=r["error"], + fields=self.parse_metadata([r])[0].fields, + ) + for r in res ] items = [] @@ -574,14 +607,18 @@ def _batch_upsert( fields=list( chain.from_iterable( self._parse_upsert(f, m.data_row_id) - for f in m.fields))).model_dump(by_alias=True)) + for f in m.fields + ) + ), + ).model_dump(by_alias=True) + ) res = _batch_operations(_batch_upsert, items, self._batch_size) return res def bulk_delete( self, deletes: List[DeleteDataRowMetadata] ) -> List[DataRowMetadataBatchResponse]: - """ Delete metadata from a datarow by specifiying the fields you want to remove + """Delete metadata from a datarow by specifiying the fields you want to remove >>> delete = DeleteDataRowMetadata( >>> data_row_id=UniqueId("datarow-id"), @@ -616,7 +653,7 @@ def bulk_delete( Args: deletes: Data row and schema ids to delete - For data row, we support UniqueId, str, and GlobalKey. + For data row, we support UniqueId, str, and GlobalKey. If you pass a str, we will assume it is a UniqueId Do not pass a mix of data row ids and global keys in the same list @@ -633,9 +670,10 @@ def bulk_delete( for i, delete in enumerate(deletes): if isinstance(delete.data_row_id, str): passed_strings = True - deletes[i] = DeleteDataRowMetadata(data_row_id=UniqueId( - delete.data_row_id), - fields=delete.fields) + deletes[i] = DeleteDataRowMetadata( + data_row_id=UniqueId(delete.data_row_id), + fields=delete.fields, + ) elif isinstance(delete.data_row_id, UniqueId): continue elif isinstance(delete.data_row_id, GlobalKey): @@ -648,10 +686,11 @@ def bulk_delete( if passed_strings: warnings.warn( "Using string for data row id will be deprecated. Please use " - "UniqueId instead.") + "UniqueId instead." 
+ ) def _batch_delete( - deletes: List[_DeleteBatchDataRowMetadata] + deletes: List[_DeleteBatchDataRowMetadata], ) -> List[DataRowMetadataBatchResponse]: query = """mutation DeleteDataRowMetadataBetaPyApi($deletes: [DataRowIdentifierCustomMetadataBatchDeleteInput!]) { deleteDataRowCustomMetadata(dataRowIdentifiers: $deletes) { @@ -664,30 +703,32 @@ def _batch_delete( } } """ - res = self._client.execute( - query, {"deletes": deletes})['deleteDataRowCustomMetadata'] + res = self._client.execute(query, {"deletes": deletes})[ + "deleteDataRowCustomMetadata" + ] failures = [] for dr in res: - dr['fields'] = [f['schemaId'] for f in dr['fields']] + dr["fields"] = [f["schemaId"] for f in dr["fields"]] failures.append(DataRowMetadataBatchResponse(**dr)) return failures items = [self._validate_delete(m) for m in deletes] - return _batch_operations(_batch_delete, - items, - batch_size=self._batch_size) + return _batch_operations( + _batch_delete, items, batch_size=self._batch_size + ) @overload def bulk_export(self, data_row_ids: List[str]) -> List[DataRowMetadata]: pass @overload - def bulk_export(self, - data_row_ids: DataRowIdentifiers) -> List[DataRowMetadata]: + def bulk_export( + self, data_row_ids: DataRowIdentifiers + ) -> List[DataRowMetadata]: pass def bulk_export(self, data_row_ids) -> List[DataRowMetadata]: - """ Exports metadata for a list of data rows + """Exports metadata for a list of data rows >>> mdo.bulk_export([data_row.uid for data_row in data_rows]) @@ -704,15 +745,20 @@ def bulk_export(self, data_row_ids) -> List[DataRowMetadata]: if not len(data_row_ids): raise ValueError("Empty list passed") - if isinstance(data_row_ids, - list) and len(data_row_ids) > 0 and isinstance( - data_row_ids[0], str): + if ( + isinstance(data_row_ids, list) + and len(data_row_ids) > 0 + and isinstance(data_row_ids[0], str) + ): data_row_ids = UniqueIds(data_row_ids) - warnings.warn("Using data row ids will be deprecated. Please use " - "UniqueIds or GlobalKeys instead.") + warnings.warn( + "Using data row ids will be deprecated. Please use " + "UniqueIds or GlobalKeys instead." + ) def _bulk_export( - _data_row_ids: DataRowIdentifiers) -> List[DataRowMetadata]: + _data_row_ids: DataRowIdentifiers, + ) -> List[DataRowMetadata]: query = """query dataRowCustomMetadataPyApi($dataRowIdentifiers: DataRowCustomMetadataDataRowIdentifiersInput) { dataRowCustomMetadata(where: {dataRowIdentifiers : $dataRowIdentifiers}) { dataRowId @@ -726,19 +772,22 @@ def _bulk_export( """ return self.parse_metadata( self._client.execute( - query, { + query, + { "dataRowIdentifiers": { "ids": [id for id in _data_row_ids], - "idType": _data_row_ids.id_type + "idType": _data_row_ids.id_type, } - })['dataRowCustomMetadata']) + }, + )["dataRowCustomMetadata"] + ) - return _batch_operations(_bulk_export, - data_row_ids, - batch_size=self._batch_size) + return _batch_operations( + _bulk_export, data_row_ids, batch_size=self._batch_size + ) def parse_upsert_metadata(self, metadata_fields) -> List[Dict[str, Any]]: - """ Converts either `DataRowMetadataField` or a dictionary representation + """Converts either `DataRowMetadataField` or a dictionary representation of `DataRowMetadataField` into a validated, flattened dictionary of metadata fields that are used to create data row metadata. 
Used internally in `Dataset.create_data_rows()` @@ -758,14 +807,18 @@ def _convert_metadata_field(metadata_field): raise ValueError( f"Custom metadata field '{metadata_field}' must have a 'value' key" ) - if not "schema_id" in metadata_field and not "name" in metadata_field: + if ( + not "schema_id" in metadata_field + and not "name" in metadata_field + ): raise ValueError( f"Custom metadata field '{metadata_field}' must have either 'schema_id' or 'name' key" ) return DataRowMetadataField( schema_id=metadata_field.get("schema_id"), name=metadata_field.get("name"), - value=metadata_field["value"]) + value=metadata_field["value"], + ) else: raise ValueError( f"Metadata field '{metadata_field}' is neither 'DataRowMetadataField' type or a dictionary" @@ -774,7 +827,8 @@ def _convert_metadata_field(metadata_field): # Convert all metadata fields to DataRowMetadataField type metadata_fields = [_convert_metadata_field(m) for m in metadata_fields] parsed_metadata = list( - chain.from_iterable(self._parse_upsert(m) for m in metadata_fields)) + chain.from_iterable(self._parse_upsert(m) for m in metadata_fields) + ) return [m.model_dump(by_alias=True) for m in parsed_metadata] def _upsert_schema( @@ -793,8 +847,8 @@ def _upsert_schema( } }""" res = self._client.execute( - query, {"data": upsert_schema.model_dump(exclude_none=True) - })['upsertCustomMetadataSchema'] + query, {"data": upsert_schema.model_dump(exclude_none=True)} + )["upsertCustomMetadataSchema"] self.refresh_ontology() return _parse_metadata_schema(res) @@ -822,9 +876,7 @@ def _load_schema_id_by_name(self, metadatum: DataRowMetadataField): self._load_option_by_name(metadatum) def _parse_upsert( - self, - metadatum: DataRowMetadataField, - data_row_id: Optional[str] = None + self, metadatum: DataRowMetadataField, data_row_id: Optional[str] = None ) -> List[_UpsertDataRowMetadataInput]: """Format for metadata upserts to GQL""" @@ -835,7 +887,8 @@ def _parse_upsert( self.refresh_ontology() if metadatum.schema_id not in self.fields_by_id: raise ValueError( - f"Schema Id `{metadatum.schema_id}` not found in ontology") + f"Schema Id `{metadatum.schema_id}` not found in ontology" + ) schema = self.fields_by_id[metadatum.schema_id] try: @@ -851,7 +904,8 @@ def _parse_upsert( parsed = _validate_enum_parse(schema, metadatum) elif schema.kind == DataRowMetadataKind.option: raise ValueError( - "An Option id should not be set as the Schema id") + "An Option id should not be set as the Schema id" + ) else: raise ValueError(f"Unknown type: {schema}") except ValueError as e: @@ -872,7 +926,8 @@ def _validate_delete(self, delete: DeleteDataRowMetadata): self.refresh_ontology() if schema_id not in self.fields_by_id: raise ValueError( - f"Schema Id `{schema_id}` not found in ontology") + f"Schema Id `{schema_id}` not found in ontology" + ) schema = self.fields_by_id[schema_id] # handle users specifying enums by adding all option enums @@ -883,10 +938,12 @@ def _validate_delete(self, delete: DeleteDataRowMetadata): return _DeleteBatchDataRowMetadata( data_row_identifier=delete.data_row_id, - schema_ids=list(delete.fields)).model_dump(by_alias=True) + schema_ids=list(delete.fields), + ).model_dump(by_alias=True) - def _validate_custom_schema_by_name(self, - name: str) -> DataRowMetadataSchema: + def _validate_custom_schema_by_name( + self, name: str + ) -> DataRowMetadataSchema: if name not in self.custom_by_name_normalized: # Fetch latest metadata ontology if metadata can't be found self.refresh_ontology() @@ -899,7 +956,7 @@ def 
_validate_custom_schema_by_name(self, def _batch_items(iterable: List[Any], size: int) -> Generator[Any, None, None]: l = len(iterable) for ndx in range(0, l, size): - yield iterable[ndx:min(ndx + size, l)] + yield iterable[ndx : min(ndx + size, l)] def _batch_operations( @@ -915,9 +972,8 @@ def _batch_operations( def _validate_parse_embedding( - field: DataRowMetadataField + field: DataRowMetadataField, ) -> List[Dict[str, Union[SchemaId, Embedding]]]: - if isinstance(field.value, list): if not (Embedding.min_items <= len(field.value) <= Embedding.max_items): raise ValueError( @@ -928,19 +984,21 @@ def _validate_parse_embedding( field.value = [float(x) for x in field.value] else: raise ValueError( - f"Expected a list for embedding. Found {type(field.value)}") + f"Expected a list for embedding. Found {type(field.value)}" + ) return [field.model_dump(by_alias=True)] def _validate_parse_number( - field: DataRowMetadataField + field: DataRowMetadataField, ) -> List[Dict[str, Union[SchemaId, str, float, int]]]: field.value = float(field.value) return [field.model_dump(by_alias=True)] def _validate_parse_datetime( - field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]: + field: DataRowMetadataField, +) -> List[Dict[str, Union[SchemaId, str]]]: if isinstance(field.value, str): field.value = format_iso_from_string(field.value) elif not isinstance(field.value, datetime): @@ -948,57 +1006,58 @@ def _validate_parse_datetime( f"Value for datetime fields must be either a string or datetime object. Found {type(field.value)}" ) - return [{ - "schemaId": field.schema_id, - "value": format_iso_datetime(field.value) - }] + return [ + {"schemaId": field.schema_id, "value": format_iso_datetime(field.value)} + ] def _validate_parse_text( - field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]: + field: DataRowMetadataField, +) -> List[Dict[str, Union[SchemaId, str]]]: if not isinstance(field.value, str): raise ValueError( f"Expected a string type for the text field. Found {type(field.value)}" ) if len(field.value) > String.metadata[0].max_length: raise ValueError( - f"String fields cannot exceed {String.metadata.max_length} characters.") + f"String fields cannot exceed {String.metadata.max_length} characters." 
+ ) return [field.model_dump(by_alias=True)] def _validate_enum_parse( - schema: DataRowMetadataSchema, - field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, dict]]]: + schema: DataRowMetadataSchema, field: DataRowMetadataField +) -> List[Dict[str, Union[SchemaId, dict]]]: if schema.options: if field.value not in {o.uid for o in schema.options}: raise ValueError( - f"Option `{field.value}` not found for {field.schema_id}") + f"Option `{field.value}` not found for {field.schema_id}" + ) else: raise ValueError("Incorrectly specified enum schema") - return [{ - "schemaId": field.schema_id, - "value": {} - }, { - "schemaId": field.value, - "value": {} - }] + return [ + {"schemaId": field.schema_id, "value": {}}, + {"schemaId": field.value, "value": {}}, + ] def _parse_metadata_schema( - unparsed: Dict[str, Union[str, List]]) -> DataRowMetadataSchema: - uid = unparsed['id'] - name = unparsed['name'] - kind = DataRowMetadataKind(unparsed['kind']) + unparsed: Dict[str, Union[str, List]], +) -> DataRowMetadataSchema: + uid = unparsed["id"] + name = unparsed["name"] + kind = DataRowMetadataKind(unparsed["kind"]) options = [ - DataRowMetadataSchema(uid=o['id'], - name=o['name'], - reserved=False, - kind=DataRowMetadataKind.option, - parent=uid) for o in unparsed['options'] + DataRowMetadataSchema( + uid=o["id"], + name=o["name"], + reserved=False, + kind=DataRowMetadataKind.option, + parent=uid, + ) + for o in unparsed["options"] ] - return DataRowMetadataSchema(uid=uid, - name=name, - reserved=False, - kind=kind, - options=options or None) + return DataRowMetadataSchema( + uid=uid, name=name, reserved=False, kind=kind, options=options or None + ) diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py index eaa37c5b7..17a3afc3d 100644 --- a/libs/labelbox/src/labelbox/schema/dataset.py +++ b/libs/labelbox/src/labelbox/schema/dataset.py @@ -15,7 +15,12 @@ from io import StringIO import requests -from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, ResourceCreationError +from labelbox.exceptions import ( + InvalidQueryError, + LabelboxError, + ResourceNotFoundError, + ResourceCreationError, +) from labelbox.orm.comparison import Comparison from labelbox.orm.db_object import DbObject, Updateable, Deletable, experimental from labelbox.orm.model import Entity, Field, Relationship @@ -25,25 +30,34 @@ from labelbox.schema.data_row import DataRow from labelbox.schema.embedding import EmbeddingVector from labelbox.schema.export_filters import DatasetExportFilters, build_filters -from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params +from labelbox.schema.export_params import ( + CatalogExportParams, + validate_catalog_export_params, +) from labelbox.schema.export_task import ExportTask from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.schema.task import Task, DataUpsertTask from labelbox.schema.user import User from labelbox.schema.iam_integration import IAMIntegration -from labelbox.schema.internal.data_row_upsert_item import (DataRowItemBase, - DataRowUpsertItem, - DataRowCreateItem) +from labelbox.schema.internal.data_row_upsert_item import ( + DataRowItemBase, + DataRowUpsertItem, + DataRowCreateItem, +) import labelbox.schema.internal.data_row_uploader as data_row_uploader -from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator +from labelbox.schema.internal.descriptor_file_creator import ( + DescriptorFileCreator, +) 
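+# Rough note: FILE_UPLOAD_THREAD_COUNT and UPSERT_CHUNK_SIZE_BYTES (imported just
+# below) are the defaults create_data_rows / upsert_data_rows fall back to when
+# uploading local files and chunking upsert manifests.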
from labelbox.schema.internal.datarow_upload_constants import ( - FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE_BYTES) + FILE_UPLOAD_THREAD_COUNT, + UPSERT_CHUNK_SIZE_BYTES, +) logger = logging.getLogger(__name__) class Dataset(DbObject, Updateable, Deletable): - """ A Dataset is a collection of DataRows. + """A Dataset is a collection of DataRows. Attributes: name (str) @@ -65,8 +79,9 @@ class Dataset(DbObject, Updateable, Deletable): # Relationships created_by = Relationship.ToOne("User", False, "created_by") organization = Relationship.ToOne("Organization", False) - iam_integration = Relationship.ToOne("IAMIntegration", False, - "iam_integration", "signer") + iam_integration = Relationship.ToOne( + "IAMIntegration", False, "iam_integration", "signer" + ) def data_rows( self, @@ -90,8 +105,11 @@ def data_rows( """ page_size = 500 # hardcode to avoid overloading the server - where_param = query.where_as_dict(Entity.DataRow, - where) if where is not None else None + where_param = ( + query.where_as_dict(Entity.DataRow, where) + if where is not None + else None + ) template = Template( """query DatasetDataRowsPyApi($$id: ID!, $$from: ID, $$first: Int, $$where: DatasetDataRowWhereInput) { @@ -101,28 +119,30 @@ def data_rows( pageInfo { hasNextPage startCursor } } } - """) + """ + ) query_str = template.substitute( - datarow_selections=query.results_query_part(Entity.DataRow)) + datarow_selections=query.results_query_part(Entity.DataRow) + ) params = { - 'id': self.uid, - 'from': from_cursor, - 'first': page_size, - 'where': where_param, + "id": self.uid, + "from": from_cursor, + "first": page_size, + "where": where_param, } return PaginatedCollection( client=self.client, query=query_str, params=params, - dereferencing=['datasetDataRows', 'nodes'], + dereferencing=["datasetDataRows", "nodes"], obj_class=Entity.DataRow, - cursor_path=['datasetDataRows', 'pageInfo', 'startCursor'], + cursor_path=["datasetDataRows", "pageInfo", "startCursor"], ) def create_data_row(self, items=None, **kwargs) -> "DataRow": - """ Creates a single DataRow belonging to this dataset. + """Creates a single DataRow belonging to this dataset. >>> dataset.create_data_row(row_data="http://my_site.com/photos/img_01.jpg") Args: @@ -148,7 +168,8 @@ def create_data_row(self, items=None, **kwargs) -> "DataRow": file_upload_thread_count = 1 completed_task = self._create_data_rows_sync( - [args], file_upload_thread_count=file_upload_thread_count) + [args], file_upload_thread_count=file_upload_thread_count + ) res = completed_task.result if res is None or len(res) == 0: @@ -156,13 +177,12 @@ def create_data_row(self, items=None, **kwargs) -> "DataRow": f"Data row upload did not complete, task status {completed_task.status} task id {completed_task.uid}" ) - return self.client.get_data_row(res[0]['id']) + return self.client.get_data_row(res[0]["id"]) def create_data_rows_sync( - self, - items, - file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT) -> None: - """ Synchronously bulk upload data rows. + self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT + ) -> None: + """Synchronously bulk upload data rows. Use this instead of `Dataset.create_data_rows` for smaller batches of data rows that need to be uploaded quickly. Cannot use this for uploads containing more than 1000 data rows. @@ -184,17 +204,18 @@ def create_data_rows_sync( """ warnings.warn( "This method is deprecated and will be " - "removed in a future release. Please use create_data_rows instead.") + "removed in a future release. 
Please use create_data_rows instead." + ) self._create_data_rows_sync( - items, file_upload_thread_count=file_upload_thread_count) + items, file_upload_thread_count=file_upload_thread_count + ) return None # Return None if no exception is raised - def _create_data_rows_sync(self, - items, - file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT - ) -> "DataUpsertTask": + def _create_data_rows_sync( + self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT + ) -> "DataUpsertTask": max_data_rows_supported = 1000 if len(items) > max_data_rows_supported: raise ValueError( @@ -203,15 +224,18 @@ def _create_data_rows_sync(self, ) if file_upload_thread_count < 1: raise ValueError( - "file_upload_thread_count must be a positive integer") + "file_upload_thread_count must be a positive integer" + ) - task: DataUpsertTask = self.create_data_rows(items, - file_upload_thread_count) + task: DataUpsertTask = self.create_data_rows( + items, file_upload_thread_count + ) task.wait_till_done() if task.has_errors(): raise ResourceCreationError( - f"Data row upload errors: {task.errors}", cause=task.uid) + f"Data row upload errors: {task.errors}", cause=task.uid + ) if task.status != "COMPLETE": raise ResourceCreationError( f"Data row upload did not complete, task status {task.status} task id {task.uid}" @@ -219,11 +243,10 @@ def _create_data_rows_sync(self, return task - def create_data_rows(self, - items, - file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT - ) -> "DataUpsertTask": - """ Asynchronously bulk upload data rows + def create_data_rows( + self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT + ) -> "DataUpsertTask": + """Asynchronously bulk upload data rows Use this instead of `Dataset.create_data_rows_sync` uploads for batches that contain more than 1000 data rows. @@ -249,7 +272,8 @@ def create_data_rows(self, if file_upload_thread_count < 1: raise ValueError( - "file_upload_thread_count must be a positive integer") + "file_upload_thread_count must be a positive integer" + ) # Usage example upload_items = self._separate_and_process_items(items) @@ -265,14 +289,15 @@ def _separate_and_process_items(self, items): return dict_items + dict_string_items def _build_from_local_paths( - self, - items: List[str], - file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT) -> List[dict]: + self, + items: List[str], + file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT, + ) -> List[dict]: uploaded_items = [] def upload_file(item): item_url = self.client.upload_file(item) - return {'row_data': item_url, 'external_id': item} + return {"row_data": item_url, "external_id": item} with ThreadPoolExecutor(file_upload_thread_count) as executor: futures = [ @@ -285,10 +310,10 @@ def upload_file(item): return uploaded_items - def data_rows_for_external_id(self, - external_id, - limit=10) -> List["DataRow"]: - """ Convenience method for getting a multiple `DataRow` belonging to this + def data_rows_for_external_id( + self, external_id, limit=10 + ) -> List["DataRow"]: + """Convenience method for getting a multiple `DataRow` belonging to this `Dataset` that has the given `external_id`. Args: @@ -315,7 +340,7 @@ def data_rows_for_external_id(self, return at_most_data_rows def data_row_for_external_id(self, external_id) -> "DataRow": - """ Convenience method for getting a single `DataRow` belonging to this + """Convenience method for getting a single `DataRow` belonging to this `Dataset` that has the given `external_id`. 
Args: @@ -329,18 +354,20 @@ def data_row_for_external_id(self, external_id) -> "DataRow": in this `DataSet` with the given external ID, or if there are multiple `DataRows` for it. """ - data_rows = self.data_rows_for_external_id(external_id=external_id, - limit=2) + data_rows = self.data_rows_for_external_id( + external_id=external_id, limit=2 + ) if len(data_rows) > 1: logger.warning( f"More than one data_row has the provided external_id : `%s`. Use function data_rows_for_external_id to fetch all", - external_id) + external_id, + ) return data_rows[0] - def export_data_rows(self, - timeout_seconds=120, - include_metadata: bool = False) -> Generator: - """ Returns a generator that produces all data rows that are currently + def export_data_rows( + self, timeout_seconds=120, include_metadata: bool = False + ) -> Generator: + """Returns a generator that produces all data rows that are currently attached to this dataset. Note: For efficiency, the data are cached for 30 minutes. Newly created data rows will not appear @@ -356,7 +383,8 @@ def export_data_rows(self, """ warnings.warn( "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) + DeprecationWarning, + ) id_param = "datasetId" metadata_param = "includeMetadataInput" query_str = """mutation GetDatasetDataRowsExportUrlPyApi($%s: ID!, $%s: Boolean!) @@ -364,10 +392,10 @@ def export_data_rows(self, """ % (id_param, metadata_param, id_param, metadata_param) sleep_time = 2 while True: - res = self.client.execute(query_str, { - id_param: self.uid, - metadata_param: include_metadata - }) + res = self.client.execute( + query_str, + {id_param: self.uid, metadata_param: include_metadata}, + ) res = res["exportDatasetDataRows"] if res["status"] == "COMPLETE": download_url = res["downloadUrl"] @@ -375,7 +403,8 @@ def export_data_rows(self, response.raise_for_status() reader = parser.reader(StringIO(response.text)) return ( - Entity.DataRow(self.client, result) for result in reader) + Entity.DataRow(self.client, result) for result in reader + ) elif res["status"] == "FAILED": raise LabelboxError("Data row export failed.") @@ -385,8 +414,9 @@ def export_data_rows(self, f"Unable to export data rows within {timeout_seconds} seconds." 
) - logger.debug("Dataset '%s' data row export, waiting for server...", - self.uid) + logger.debug( + "Dataset '%s' data row export, waiting for server...", self.uid + ) time.sleep(sleep_time) def export( @@ -439,7 +469,7 @@ def export_v2( >>> task.result """ task, is_streamable = self._export(task_name, filters, params) - if (is_streamable): + if is_streamable: return ExportTask(task, True) return task @@ -450,36 +480,41 @@ def _export( params: Optional[CatalogExportParams] = None, streamable: bool = False, ) -> Tuple[Task, bool]: - _params = params or CatalogExportParams({ - "attachments": False, - "embeddings": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "model_run_ids": None, - "project_ids": None, - "interpolated_frames": False, - "all_projects": False, - "all_model_runs": False, - }) + _params = params or CatalogExportParams( + { + "attachments": False, + "embeddings": False, + "metadata_fields": False, + "data_row_details": False, + "project_details": False, + "performance_details": False, + "label_details": False, + "media_type_override": None, + "model_run_ids": None, + "project_ids": None, + "interpolated_frames": False, + "all_projects": False, + "all_model_runs": False, + } + ) validate_catalog_export_params(_params) - _filters = filters or DatasetExportFilters({ - "last_activity_at": None, - "label_created_at": None, - "data_row_ids": None, - "global_keys": None, - }) + _filters = filters or DatasetExportFilters( + { + "last_activity_at": None, + "label_created_at": None, + "data_row_ids": None, + "global_keys": None, + } + ) mutation_name = "exportDataRowsInCatalog" create_task_query_str = ( f"mutation {mutation_name}PyApi" f"($input: ExportDataRowsInCatalogInput!)" - f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}") - media_type_override = _params.get('media_type_override', None) + f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}" + ) + media_type_override = _params.get("media_type_override", None) if task_name is None: task_name = f"Export v2: dataset - {self.name}" @@ -494,61 +529,53 @@ def _export( }, "isStreamableReady": True, "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeEmbeddings": - _params.get('embeddings', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), - "includePredictions": - _params.get('predictions', False), - "projectIds": - _params.get('project_ids', None), - "modelRunIds": - _params.get('model_run_ids', None), - "allProjects": - _params.get('all_projects', False), - "allModelRuns": - _params.get('all_model_runs', False), + "mediaTypeOverride": media_type_override.value + if media_type_override is not None + else None, + "includeAttachments": _params.get("attachments", False), + "includeEmbeddings": _params.get("embeddings", False), + "includeMetadata": _params.get("metadata_fields", False), + "includeDataRowDetails": _params.get( + "data_row_details", False + ), + "includeProjectDetails": 
_params.get( + "project_details", False + ), + "includePerformanceDetails": _params.get( + "performance_details", False + ), + "includeLabelDetails": _params.get("label_details", False), + "includeInterpolatedFrames": _params.get( + "interpolated_frames", False + ), + "includePredictions": _params.get("predictions", False), + "projectIds": _params.get("project_ids", None), + "modelRunIds": _params.get("model_run_ids", None), + "allProjects": _params.get("all_projects", False), + "allModelRuns": _params.get("all_model_runs", False), }, "streamable": streamable, } } search_query = build_filters(self.client, _filters) - search_query.append({ - "ids": [self.uid], - "operator": "is", - "type": "dataset" - }) + search_query.append( + {"ids": [self.uid], "operator": "is", "type": "dataset"} + ) query_params["input"]["filters"]["searchQuery"]["query"] = search_query - res = self.client.execute(create_task_query_str, - query_params, - error_log_key="errors") + res = self.client.execute( + create_task_query_str, query_params, error_log_key="errors" + ) res = res[mutation_name] task_id = res["taskId"] is_streamable = res["isStreamable"] return Task.get_task(self.client, task_id), is_streamable - def upsert_data_rows(self, - items, - file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT - ) -> "DataUpsertTask": + def upsert_data_rows( + self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT + ) -> "DataUpsertTask": """ Upserts data rows in this dataset. When "key" is provided, and it references an existing data row, an update will be performed. When "key" is not provided a new data row will be created. @@ -585,19 +612,19 @@ def upsert_data_rows(self, def _exec_upsert_data_rows( self, specs: List[DataRowItemBase], - file_upload_thread_count: int = FILE_UPLOAD_THREAD_COUNT + file_upload_thread_count: int = FILE_UPLOAD_THREAD_COUNT, ) -> "DataUpsertTask": - manifest = data_row_uploader.upload_in_chunks( client=self.client, specs=specs, file_upload_thread_count=file_upload_thread_count, - max_chunk_size_bytes=UPSERT_CHUNK_SIZE_BYTES) + max_chunk_size_bytes=UPSERT_CHUNK_SIZE_BYTES, + ) data = json.dumps(manifest.model_dump()).encode("utf-8") - manifest_uri = self.client.upload_data(data, - content_type="application/json", - filename="manifest.json") + manifest_uri = self.client.upload_data( + data, content_type="application/json", filename="manifest.json" + ) query_str = """ mutation UpsertDataRowsPyApi($manifestUri: String!) { @@ -614,44 +641,47 @@ def _exec_upsert_data_rows( return task def add_iam_integration( - self, iam_integration: Union[str, - IAMIntegration]) -> IAMIntegration: - """ - Sets the IAM integration for the dataset. IAM integration is used to sign URLs for data row assets. - - Args: - iam_integration (Union[str, IAMIntegration]): IAM integration object or IAM integration id. - - Returns: - IAMIntegration: IAM integration object. - - Raises: - LabelboxError: If the IAM integration can't be set. + self, iam_integration: Union[str, IAMIntegration] + ) -> IAMIntegration: + """ + Sets the IAM integration for the dataset. IAM integration is used to sign URLs for data row assets. 
- Examples: - - >>> # Get all IAM integrations - >>> iam_integrations = client.get_organization().get_iam_integrations() - >>> - >>> # Get IAM integration id - >>> iam_integration_id = [integration.uid for integration - >>> in iam_integrations - >>> if integration.name == "My S3 integration"][0] - >>> - >>> # Set IAM integration for integration id - >>> dataset.set_iam_integration(iam_integration_id) - >>> - >>> # Get IAM integration object - >>> iam_integration = [integration.uid for integration - >>> in iam_integrations - >>> if integration.name == "My S3 integration"][0] - >>> - >>> # Set IAM integration for IAMIntegrtion object - >>> dataset.set_iam_integration(iam_integration) + Args: + iam_integration (Union[str, IAMIntegration]): IAM integration object or IAM integration id. + + Returns: + IAMIntegration: IAM integration object. + + Raises: + LabelboxError: If the IAM integration can't be set. + + Examples: + + >>> # Get all IAM integrations + >>> iam_integrations = client.get_organization().get_iam_integrations() + >>> + >>> # Get IAM integration id + >>> iam_integration_id = [integration.uid for integration + >>> in iam_integrations + >>> if integration.name == "My S3 integration"][0] + >>> + >>> # Set IAM integration for integration id + >>> dataset.set_iam_integration(iam_integration_id) + >>> + >>> # Get IAM integration object + >>> iam_integration = [integration.uid for integration + >>> in iam_integrations + >>> if integration.name == "My S3 integration"][0] + >>> + >>> # Set IAM integration for IAMIntegrtion object + >>> dataset.set_iam_integration(iam_integration) """ - iam_integration_id = iam_integration.uid if isinstance( - iam_integration, IAMIntegration) else iam_integration + iam_integration_id = ( + iam_integration.uid + if isinstance(iam_integration, IAMIntegration) + else iam_integration + ) query = """ mutation SetSignerForDatasetPyApi($signerId: ID!, $datasetId: ID!) 
{ @@ -667,29 +697,30 @@ def add_iam_integration( } """ - response = self.client.execute(query, { - "signerId": iam_integration_id, - "datasetId": self.uid - }) + response = self.client.execute( + query, {"signerId": iam_integration_id, "datasetId": self.uid} + ) if not response: - raise ResourceNotFoundError(IAMIntegration, { - "signerId": iam_integration_id, - "datasetId": self.uid - }) + raise ResourceNotFoundError( + IAMIntegration, + {"signerId": iam_integration_id, "datasetId": self.uid}, + ) try: - iam_integration_id = response.get("setSignerForDataset", - {}).get("signer", {})["id"] + iam_integration_id = response.get("setSignerForDataset", {}).get( + "signer", {} + )["id"] return [ - integration for integration in - self.client.get_organization().get_iam_integrations() + integration + for integration in self.client.get_organization().get_iam_integrations() if integration.uid == iam_integration_id ][0] except: raise LabelboxError( - f"Can't retrieve IAM integration {iam_integration_id}") + f"Can't retrieve IAM integration {iam_integration_id}" + ) def remove_iam_integration(self) -> None: """ diff --git a/libs/labelbox/src/labelbox/schema/embedding.py b/libs/labelbox/src/labelbox/schema/embedding.py index a67b82d38..dd5224c7e 100644 --- a/libs/labelbox/src/labelbox/schema/embedding.py +++ b/libs/labelbox/src/labelbox/schema/embedding.py @@ -13,6 +13,7 @@ class EmbeddingVector(BaseModel): vector (list): The raw vector values - the number of entries should match the Embedding's dimensions clusters (list): The cluster groupings """ + embedding_id: str vector: List[float] clusters: Optional[List[int]] = None @@ -37,6 +38,7 @@ class Embedding(BaseModel): dims (int): Refers to the size of the vector space in which words, phrases, or other entities are embedded custom (bool): Indicates whether the embedding is a Precomputed embedding or a Custom embedding """ + id: str name: str custom: bool @@ -54,10 +56,11 @@ def delete(self): """ self._client.delete_embedding(self.id) - def import_vectors_from_file(self, - path: str, - callback: Optional[Callable[[Dict[str, Any]], - None]] = None): + def import_vectors_from_file( + self, + path: str, + callback: Optional[Callable[[Dict[str, Any]], None]] = None, + ): """ Import vectors into a given embedding from an NDJSON file. An NDJSON file consists of newline delimited JSON. Each line of the file diff --git a/libs/labelbox/src/labelbox/schema/enums.py b/libs/labelbox/src/labelbox/schema/enums.py index c08e91bfa..6f8aebc58 100644 --- a/libs/labelbox/src/labelbox/schema/enums.py +++ b/libs/labelbox/src/labelbox/schema/enums.py @@ -2,7 +2,7 @@ class BulkImportRequestState(Enum): - """ State of the import job when importing annotations (RUNNING, FAILED, or FINISHED). + """State of the import job when importing annotations (RUNNING, FAILED, or FINISHED). If you are not usinig MEA continue using BulkImportRequest. AnnotationImports are in beta and will change soon. @@ -20,13 +20,14 @@ class BulkImportRequestState(Enum): * - FINISHED - Indicates the import job is no longer running. Check `BulkImportRequest.statuses` for more information """ + RUNNING = "RUNNING" FAILED = "FAILED" FINISHED = "FINISHED" class AnnotationImportState(Enum): - """ State of the import job when importing annotations (RUNNING, FAILED, or FINISHED). + """State of the import job when importing annotations (RUNNING, FAILED, or FINISHED). .. list-table:: :widths: 15 150 @@ -41,23 +42,25 @@ class AnnotationImportState(Enum): * - FINISHED - Indicates the import job is no longer running. 
Check `AnnotationImport.statuses` for more information """ + RUNNING = "RUNNING" FAILED = "FAILED" FINISHED = "FINISHED" class CollectionJobStatus(Enum): - """ Status of an asynchronous job over a collection. - - * - State - - Description - * - SUCCESS - - Indicates job has successfully processed entire collection of data - * - PARTIAL SUCCESS - - Indicates some data in the collection has succeeded and other data have failed - * - FAILURE - - Indicates job has failed to process entire collection of data + """Status of an asynchronous job over a collection. + + * - State + - Description + * - SUCCESS + - Indicates job has successfully processed entire collection of data + * - PARTIAL SUCCESS + - Indicates some data in the collection has succeeded and other data have failed + * - FAILURE + - Indicates job has failed to process entire collection of data """ + SUCCESS = "SUCCESS" PARTIAL_SUCCESS = "PARTIAL SUCCESS" - FAILURE = "FAILURE" \ No newline at end of file + FAILURE = "FAILURE" diff --git a/libs/labelbox/src/labelbox/schema/export_filters.py b/libs/labelbox/src/labelbox/schema/export_filters.py index aa97cbced..641adc011 100644 --- a/libs/labelbox/src/labelbox/schema/export_filters.py +++ b/libs/labelbox/src/labelbox/schema/export_filters.py @@ -3,6 +3,7 @@ from datetime import datetime, timezone from typing import Collection, Dict, Tuple, List, Optional from labelbox.typing_imports import Literal + if sys.version_info >= (3, 8): from typing import TypedDict else: @@ -47,8 +48,9 @@ class ProjectExportFilters(SharedExportFilters): Example: >>> ["clgo3lyax0000veeezdbu3ws4"] """ - workflow_status: Optional[Literal["ToLabel", "InReview", "InRework", - "Done"]] + workflow_status: Optional[ + Literal["ToLabel", "InReview", "InRework", "Done"] + ] """ Export data rows matching workflow status Example: >>> "InReview" @@ -68,7 +70,7 @@ class DatarowExportFilters(BaseExportFilters): def validate_datetime(datetime_str: str) -> bool: - """helper function to validate that datetime's format: "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss" + """helper function to validate that datetime's format: "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss" or ISO 8061 format "YYYY-MM-DDThh:mm:ss±hhmm" (Example: "2023-05-23T14:30:00+0530")""" if datetime_str: for fmt in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", ISO_8061_FORMAT): @@ -78,8 +80,7 @@ def validate_datetime(datetime_str: str) -> bool: except ValueError: pass raise ValueError(f"""Incorrect format for: {datetime_str}. 
- Format must be \"YYYY-MM-DD\" or \"YYYY-MM-DD hh:mm:ss\" or ISO 8061 format \"YYYY-MM-DDThh:mm:ss±hhmm\"""" - ) + Format must be \"YYYY-MM-DD\" or \"YYYY-MM-DD hh:mm:ss\" or ISO 8061 format \"YYYY-MM-DDThh:mm:ss±hhmm\"""") return True @@ -96,8 +97,10 @@ def convert_to_utc_if_iso8061(datetime_str: str, timezone_str: Optional[str]): def validate_one_of_data_row_ids_or_global_keys(filters): - if filters.get("data_row_ids") is not None and filters.get( - "global_keys") is not None: + if ( + filters.get("data_row_ids") is not None + and filters.get("global_keys") is not None + ): raise ValueError( "data_rows and global_keys cannot both be present in export filters" ) @@ -117,9 +120,11 @@ def _get_timezone() -> str: tz_res = client.execute(timezone_query_str) return tz_res["user"]["timezone"] or "UTC" - def _build_id_filters(ids: list, - type_name: str, - search_where_limit: int = SEARCH_LIMIT_PER_EXPORT_V2): + def _build_id_filters( + ids: list, + type_name: str, + search_where_limit: int = SEARCH_LIMIT_PER_EXPORT_V2, + ): if not isinstance(ids, list): raise ValueError(f"{type_name} filter expects a list.") if len(ids) == 0: @@ -136,85 +141,91 @@ def _build_id_filters(ids: list, if last_activity_at: timezone = _get_timezone() start, end = last_activity_at - if (start is not None and end is not None): + if start is not None and end is not None: [validate_datetime(date) for date in last_activity_at] start, timezone = convert_to_utc_if_iso8061(start, timezone) end, timezone = convert_to_utc_if_iso8061(end, timezone) - search_query.append({ - "type": "data_row_last_activity_at", - "value": { - "operator": "BETWEEN", - "timezone": timezone, + search_query.append( + { + "type": "data_row_last_activity_at", "value": { - "min": start, - "max": end - } + "operator": "BETWEEN", + "timezone": timezone, + "value": {"min": start, "max": end}, + }, } - }) - elif (start is not None): + ) + elif start is not None: validate_datetime(start) start, timezone = convert_to_utc_if_iso8061(start, timezone) - search_query.append({ - "type": "data_row_last_activity_at", - "value": { - "operator": "GREATER_THAN_OR_EQUAL", - "timezone": timezone, - "value": start + search_query.append( + { + "type": "data_row_last_activity_at", + "value": { + "operator": "GREATER_THAN_OR_EQUAL", + "timezone": timezone, + "value": start, + }, } - }) - elif (end is not None): + ) + elif end is not None: validate_datetime(end) end, timezone = convert_to_utc_if_iso8061(end, timezone) - search_query.append({ - "type": "data_row_last_activity_at", - "value": { - "operator": "LESS_THAN_OR_EQUAL", - "timezone": timezone, - "value": end + search_query.append( + { + "type": "data_row_last_activity_at", + "value": { + "operator": "LESS_THAN_OR_EQUAL", + "timezone": timezone, + "value": end, + }, } - }) + ) label_created_at = filters.get("label_created_at") if label_created_at: timezone = _get_timezone() start, end = label_created_at - if (start is not None and end is not None): + if start is not None and end is not None: [validate_datetime(date) for date in label_created_at] start, timezone = convert_to_utc_if_iso8061(start, timezone) end, timezone = convert_to_utc_if_iso8061(end, timezone) - search_query.append({ - "type": "labeled_at", - "value": { - "operator": "BETWEEN", - "timezone": timezone, + search_query.append( + { + "type": "labeled_at", "value": { - "min": start, - "max": end - } + "operator": "BETWEEN", + "timezone": timezone, + "value": {"min": start, "max": end}, + }, } - }) - elif (start is not None): + ) + elif start is not None: 
validate_datetime(start) start, timezone = convert_to_utc_if_iso8061(start, timezone) - search_query.append({ - "type": "labeled_at", - "value": { - "operator": "GREATER_THAN_OR_EQUAL", - "timezone": timezone, - "value": start + search_query.append( + { + "type": "labeled_at", + "value": { + "operator": "GREATER_THAN_OR_EQUAL", + "timezone": timezone, + "value": start, + }, } - }) - elif (end is not None): + ) + elif end is not None: validate_datetime(end) end, timezone = convert_to_utc_if_iso8061(end, timezone) - search_query.append({ - "type": "labeled_at", - "value": { - "operator": "LESS_THAN_OR_EQUAL", - "timezone": timezone, - "value": end + search_query.append( + { + "type": "labeled_at", + "value": { + "operator": "LESS_THAN_OR_EQUAL", + "timezone": timezone, + "value": end, + }, } - }) + ) data_row_ids = filters.get("data_row_ids") if data_row_ids is not None: @@ -240,9 +251,8 @@ def _build_id_filters(ids: list, if workflow_status == "ToLabel": search_query.append({"type": "task_queue_not_exist"}) else: - search_query.append({ - "type": 'task_queue_status', - "status": workflow_status - }) + search_query.append( + {"type": "task_queue_status", "status": workflow_status} + ) return search_query diff --git a/libs/labelbox/src/labelbox/schema/export_params.py b/libs/labelbox/src/labelbox/schema/export_params.py index 5229e2bfa..b15bc2828 100644 --- a/libs/labelbox/src/labelbox/schema/export_params.py +++ b/libs/labelbox/src/labelbox/schema/export_params.py @@ -5,6 +5,7 @@ EXPORT_LIMIT = 30 from labelbox.schema.media_type import MediaType + if sys.version_info >= (3, 8): from typing import TypedDict else: @@ -49,9 +50,11 @@ def _validate_array_length(array, max_length, array_name): def validate_catalog_export_params(params: CatalogExportParams): if "model_run_ids" in params and params["model_run_ids"] is not None: - _validate_array_length(params["model_run_ids"], EXPORT_LIMIT, - "model_run_ids") + _validate_array_length( + params["model_run_ids"], EXPORT_LIMIT, "model_run_ids" + ) if "project_ids" in params and params["project_ids"] is not None: - _validate_array_length(params["project_ids"], EXPORT_LIMIT, - "project_ids") + _validate_array_length( + params["project_ids"], EXPORT_LIMIT, "project_ids" + ) diff --git a/libs/labelbox/src/labelbox/schema/export_task.py b/libs/labelbox/src/labelbox/schema/export_task.py index 423e66ceb..a144f4c76 100644 --- a/libs/labelbox/src/labelbox/schema/export_task.py +++ b/libs/labelbox/src/labelbox/schema/export_task.py @@ -111,7 +111,7 @@ class JsonConverterOutput: class JsonConverter(Converter[JsonConverterOutput]): # pylint: disable=too-few-public-methods """Converts JSON data. - + Deprecated: This converter is deprecated and will be removed in a future release. 
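For orientation, the filter-building code above consumes a plain dict; a sketch with illustrative values (the dates, keys, and status below are not taken from this patch):

filters = {
    # [start, end]; either bound may be None for an open-ended range
    "last_activity_at": ["2024-01-01", "2024-03-31"],
    "label_created_at": [None, "2024-03-31 23:59:59"],
    # one of "ToLabel", "InReview", "InRework", "Done"
    "workflow_status": "InReview",
    # data_row_ids and global_keys cannot both be present
    "global_keys": ["image-001", "image-002"],
}
search_query = build_filters(client, filters)  # client: an authenticated labelbox.Client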
""" @@ -133,16 +133,21 @@ def _find_json_object_offsets(self, data: str) -> List[Tuple[int, int]]: current_object_start = index # we need to account for scenarios where data lands in the middle of an object # and the object is not the last one in the data - if index > 0 and data[index - - 1] == "\n" and not object_offsets: + if ( + index > 0 + and data[index - 1] == "\n" + and not object_offsets + ): object_offsets.append((0, index - 1)) elif char == "}" and stack: stack.pop() # this covers cases where the last object is either followed by a newline or # it is missing - if len(stack) == 0 and (len(data) == index + 1 or - data[index + 1] == "\n" - ) and current_object_start is not None: + if ( + len(stack) == 0 + and (len(data) == index + 1 or data[index + 1] == "\n") + and current_object_start is not None + ): object_offsets.append((current_object_start, index + 1)) current_object_start = None @@ -162,7 +167,7 @@ def convert( yield JsonConverterOutput( current_offset=current_offset + offset_start, current_line=current_line + line, - json_str=raw_data[offset_start:offset_end + 1].strip(), + json_str=raw_data[offset_start : offset_end + 1].strip(), ) @@ -179,8 +184,7 @@ class FileConverterOutput: class FileConverter(Converter[FileConverterOutput]): - """Converts data to a file. - """ + """Converts data to a file.""" def __init__(self, file_path: str) -> None: super().__init__() @@ -224,8 +228,8 @@ def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]: """Retrieves the file.""" def _get_file_content( - self, query: str, variables: dict, - result_field_name: str) -> Tuple[_MetadataFileInfo, str]: + self, query: str, variables: dict, result_field_name: str + ) -> Tuple[_MetadataFileInfo, str]: """Runs the query.""" res = self._ctx.client.execute(query, variables, error_log_key="errors") res = res["task"][result_field_name] @@ -233,14 +237,17 @@ def _get_file_content( if not file_info: raise ValueError( f"Task {self._ctx.task_id} does not have a metadata file for the " - f"{self._ctx.stream_type.value} stream") + f"{self._ctx.stream_type.value} stream" + ) response = requests.get(file_info.file, timeout=30) response.raise_for_status() - assert len( - response.content - ) == file_info.offsets.end - file_info.offsets.start + 1, ( + assert ( + len(response.content) + == file_info.offsets.end - file_info.offsets.start + 1 + ), ( f"expected {file_info.offsets.end - file_info.offsets.start + 1} bytes, " - f"got {len(response.content)} bytes") + f"got {len(response.content)} bytes" + ) return file_info, response.text @@ -260,8 +267,9 @@ def __init__( f"offset is out of range, max offset is {self._ctx.metadata_header.total_size - 1}" ) - def _find_line_at_offset(self, file_content: str, - target_offset: int) -> int: + def _find_line_at_offset( + self, file_content: str, target_offset: int + ) -> int: # TODO: Remove this, incorrect parsing of JSON to find braces stack = [] line_number = 0 @@ -288,22 +296,24 @@ def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]: f"{{task(where: $where)" f"{{{'exportFileFromOffset'}(streamType: $streamType, offset: $offset)" f"{{offsets {{start end}} lines {{start end}} file}}" - f"}}}}") + f"}}}}" + ) variables = { - "where": { - "id": self._ctx.task_id - }, + "where": {"id": self._ctx.task_id}, "streamType": self._ctx.stream_type.value, "offset": str(self._current_offset), } file_info, file_content = self._get_file_content( - query, variables, "exportFileFromOffset") + query, variables, "exportFileFromOffset" + ) if self._current_line is 
None: self._current_line = self._find_line_at_offset( - file_content, self._current_offset - file_info.offsets.start) + file_content, self._current_offset - file_info.offsets.start + ) self._current_line += file_info.lines.start - file_content = file_content[self._current_offset - - file_info.offsets.start:] + file_content = file_content[ + self._current_offset - file_info.offsets.start : + ] file_info.offsets.start = self._current_offset file_info.lines.start = self._current_line self._current_offset = file_info.offsets.end + 1 @@ -357,22 +367,24 @@ def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]: f"{{task(where: $where)" f"{{{'exportFileFromLine'}(streamType: $streamType, line: $line)" f"{{offsets {{start end}} lines {{start end}} file}}" - f"}}}}") + f"}}}}" + ) variables = { - "where": { - "id": self._ctx.task_id - }, + "where": {"id": self._ctx.task_id}, "streamType": self._ctx.stream_type.value, "line": self._current_line, } file_info, file_content = self._get_file_content( - query, variables, "exportFileFromLine") + query, variables, "exportFileFromLine" + ) if self._current_offset is None: self._current_offset = self._find_offset_of_line( - file_content, self._current_line - file_info.lines.start) + file_content, self._current_line - file_info.lines.start + ) self._current_offset += file_info.offsets.start - file_content = file_content[self._current_offset - - file_info.offsets.start:] + file_content = file_content[ + self._current_offset - file_info.offsets.start : + ] file_info.offsets.start = self._current_offset file_info.lines.start = self._current_line self._current_offset = file_info.offsets.end + 1 @@ -394,7 +406,7 @@ def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]: class _MultiGCSFileReader(_Reader): # pylint: disable=too-few-public-methods """Reads data from multiple GCS files in a seamless way. - + Deprecated: This reader is deprecated and will be removed in a future release. """ @@ -437,7 +449,9 @@ def __init__( def __iter__(self): yield from self._fetch() - def _fetch(self,) -> Iterator[OutputT]: + def _fetch( + self, + ) -> Iterator[OutputT]: """Fetches the result data. Returns an iterator that yields the offset and the data. """ @@ -448,25 +462,27 @@ def _fetch(self,) -> Iterator[OutputT]: with self._converter as converter: for file_info, raw_data in stream: for output in converter.convert( - Converter.ConverterInputArgs(self._ctx, file_info, - raw_data)): + Converter.ConverterInputArgs(self._ctx, file_info, raw_data) + ): yield output def with_offset(self, offset: int) -> "Stream[OutputT]": """Sets the offset for the stream.""" self._reader.set_retrieval_strategy( - FileRetrieverByOffset(self._ctx, offset)) + FileRetrieverByOffset(self._ctx, offset) + ) return self def with_line(self, line: int) -> "Stream[OutputT]": """Sets the line number for the stream.""" - self._reader.set_retrieval_strategy(FileRetrieverByLine( - self._ctx, line)) + self._reader.set_retrieval_strategy( + FileRetrieverByLine(self._ctx, line) + ) return self def start( - self, - stream_handler: Optional[Callable[[OutputT], None]] = None) -> None: + self, stream_handler: Optional[Callable[[OutputT], None]] = None + ) -> None: """Starts streaming the result data. Calls the stream_handler for each result. 
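A sketch of how the Stream helpers above are typically consumed; export_task stands in for an ExportTask obtained from an export call, and the converter/stream-type pairing follows the get_stream signature shown later in this patch:

from labelbox.schema.export_task import JsonConverter, StreamType

# export_task: an ExportTask returned by an export call (assumed to already exist)
stream = export_task.get_stream(
    converter=JsonConverter(), stream_type=StreamType.RESULT
)

# Either iterate the stream directly ...
for output in stream.with_offset(0):
    print(output.current_line, output.json_str)

# ... or register a handler and let start() drive the callbacks.
stream.with_line(0).start(stream_handler=lambda output: print(output.json_str))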
""" @@ -501,16 +517,16 @@ def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]: f"{{task(where: $where)" f"{{{'exportFileFromOffset'}(streamType: $streamType, offset: $offset)" f"{{offsets {{start end}} lines {{start end}} file}}" - f"}}}}") + f"}}}}" + ) variables = { - "where": { - "id": self._ctx.task_id - }, + "where": {"id": self._ctx.task_id}, "streamType": self._ctx.stream_type.value, "offset": str(self._current_offset), } file_info, file_content = self._get_file_content( - query, variables, "exportFileFromOffset") + query, variables, "exportFileFromOffset" + ) file_info.offsets.start = self._current_offset file_info.lines.start = self._current_line self._current_offset = file_info.offsets.end + 1 @@ -529,12 +545,15 @@ def __init__( self._reader = _BufferedGCSFileReader() self._converter = _BufferedJsonConverter() self._reader.set_retrieval_strategy( - _BufferedFileRetrieverByOffset(self._ctx, 0)) + _BufferedFileRetrieverByOffset(self._ctx, 0) + ) def __iter__(self): yield from self._fetch() - def _fetch(self,) -> Iterator[OutputT]: + def _fetch( + self, + ) -> Iterator[OutputT]: """Fetches the result data. Returns an iterator that yields the offset and the data. """ @@ -545,13 +564,13 @@ def _fetch(self,) -> Iterator[OutputT]: with self._converter as converter: for file_info, raw_data in stream: for output in converter.convert( - Converter.ConverterInputArgs(self._ctx, file_info, - raw_data)): + Converter.ConverterInputArgs(self._ctx, file_info, raw_data) + ): yield output def start( - self, - stream_handler: Optional[Callable[[OutputT], None]] = None) -> None: + self, stream_handler: Optional[Callable[[OutputT], None]] = None + ) -> None: """Starts streaming the result data. Calls the stream_handler for each result. """ @@ -564,12 +583,12 @@ def start( @dataclass class BufferedJsonConverterOutput: """Output with the JSON object""" + json: Any class _BufferedJsonConverter(Converter[BufferedJsonConverterOutput]): - """Converts JSON data in a buffered manner - """ + """Converts JSON data in a buffered manner""" def convert( self, input_args: Converter.ConverterInputArgs @@ -592,7 +611,7 @@ def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]: if not self._retrieval_strategy: raise ValueError("retrieval strategy not set") # create a buffer - with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file: + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: result = self._retrieval_strategy.get_next_chunk() while result: _, raw_data = result @@ -604,12 +623,16 @@ def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]: temp_file.write(raw_data) result = self._retrieval_strategy.get_next_chunk() # read buffer - with open(temp_file.name, 'r') as temp_file_reopened: + with open(temp_file.name, "r") as temp_file_reopened: for idx, line in enumerate(temp_file_reopened): - yield _MetadataFileInfo(offsets=Range(start=0, - end=len(line) - 1), - lines=Range(start=idx, end=idx + 1), - file=temp_file.name), line + yield ( + _MetadataFileInfo( + offsets=Range(start=0, end=len(line) - 1), + lines=Range(start=idx, end=idx + 1), + file=temp_file.name, + ), + line, + ) # manually delete buffer os.unlink(temp_file.name) @@ -632,8 +655,11 @@ def __init__(self, task: Task, is_export_v2: bool = False) -> None: self._task = task def __repr__(self): - return f"" if getattr( - self, "uid", None) else "" + return ( + f"" + if getattr(self, "uid", None) + else "" + ) def __str__(self): properties_to_include = [ @@ -702,8 +728,13 @@ def result_url(self): "This 
property is only available for export_v2 tasks due to compatibility reasons, please use streamable errors instead" ) base_url = self._task.client.rest_endpoint - return base_url + '/export-results/' + self._task.uid + '/' + self._task.client.get_organization( - ).uid + return ( + base_url + + "/export-results/" + + self._task.uid + + "/" + + self._task.client.get_organization().uid + ) @property def errors_url(self): @@ -715,8 +746,13 @@ def errors_url(self): if not self.has_errors(): return None base_url = self._task.client.rest_endpoint - return base_url + '/export-errors/' + self._task.uid + '/' + self._task.client.get_organization( - ).uid + return ( + base_url + + "/export-errors/" + + self._task.uid + + "/" + + self._task.client.get_organization().uid + ) @property def errors(self): @@ -736,14 +772,18 @@ def errors(self): data = [] metadata_header = ExportTask._get_metadata_header( - self._task.client, self._task.uid, StreamType.ERRORS) + self._task.client, self._task.uid, StreamType.ERRORS + ) if metadata_header is None: return None BufferedStream( _TaskContext( - self._task.client, self._task.uid, StreamType.ERRORS, - metadata_header),).start( - stream_handler=lambda output: data.append(output.json)) + self._task.client, + self._task.uid, + StreamType.ERRORS, + metadata_header, + ), + ).start(stream_handler=lambda output: data.append(output.json)) return data @property @@ -757,14 +797,18 @@ def result(self): data = [] metadata_header = ExportTask._get_metadata_header( - self._task.client, self._task.uid, StreamType.RESULT) + self._task.client, self._task.uid, StreamType.RESULT + ) if metadata_header is None: return [] BufferedStream( _TaskContext( - self._task.client, self._task.uid, StreamType.RESULT, - metadata_header),).start( - stream_handler=lambda output: data.append(output.json)) + self._task.client, + self._task.uid, + StreamType.RESULT, + metadata_header, + ), + ).start(stream_handler=lambda output: data.append(output.json)) return data return self._task.result_url @@ -798,15 +842,17 @@ def wait_till_done(self, timeout_seconds: int = 7200) -> None: @staticmethod @lru_cache(maxsize=5) def _get_metadata_header( - client, task_id: str, - stream_type: StreamType) -> Union[_MetadataHeader, None]: + client, task_id: str, stream_type: StreamType + ) -> Union[_MetadataHeader, None]: """Returns the total file size for a specific task.""" - query = (f"query GetExportMetadataHeaderPyApi" - f"($where: WhereUniqueIdInput, $streamType: TaskStreamType!)" - f"{{task(where: $where)" - f"{{{'exportMetadataHeader'}(streamType: $streamType)" - f"{{totalSize totalLines}}" - f"}}}}") + query = ( + f"query GetExportMetadataHeaderPyApi" + f"($where: WhereUniqueIdInput, $streamType: TaskStreamType!)" + f"{{task(where: $where)" + f"{{{'exportMetadataHeader'}(streamType: $streamType)" + f"{{totalSize totalLines}}" + f"}}}}" + ) variables = {"where": {"id": task_id}, "streamType": stream_type.value} res = client.execute(query, variables, error_log_key="errors") res = res["task"]["exportMetadataHeader"] @@ -818,8 +864,9 @@ def get_total_file_size(self, stream_type: StreamType) -> Union[int, None]: raise ExportTask.ExportTaskException("Task failed") if self._task.status != "COMPLETE": raise ExportTask.ExportTaskException("Task is not ready yet") - header = ExportTask._get_metadata_header(self._task.client, - self._task.uid, stream_type) + header = ExportTask._get_metadata_header( + self._task.client, self._task.uid, stream_type + ) return header.total_size if header else None def get_total_lines(self, 
stream_type: StreamType) -> Union[int, None]: @@ -828,8 +875,9 @@ def get_total_lines(self, stream_type: StreamType) -> Union[int, None]: raise ExportTask.ExportTaskException("Task failed") if self._task.status != "COMPLETE": raise ExportTask.ExportTaskException("Task is not ready yet") - header = ExportTask._get_metadata_header(self._task.client, - self._task.uid, stream_type) + header = ExportTask._get_metadata_header( + self._task.client, self._task.uid, stream_type + ) return header.total_lines if header else None def has_result(self) -> bool: @@ -864,15 +912,18 @@ def get_buffered_stream( if self._task.status != "COMPLETE": raise ExportTask.ExportTaskException("Task is not ready yet") - metadata_header = self._get_metadata_header(self._task.client, - self._task.uid, stream_type) + metadata_header = self._get_metadata_header( + self._task.client, self._task.uid, stream_type + ) if metadata_header is None: raise ValueError( f"Task {self._task.uid} does not have a {stream_type.value} stream" ) return BufferedStream( - _TaskContext(self._task.client, self._task.uid, stream_type, - metadata_header),) + _TaskContext( + self._task.client, self._task.uid, stream_type, metadata_header + ), + ) @overload def get_stream( @@ -906,15 +957,17 @@ def get_stream( if self._task.status != "COMPLETE": raise ExportTask.ExportTaskException("Task is not ready yet") - metadata_header = self._get_metadata_header(self._task.client, - self._task.uid, stream_type) + metadata_header = self._get_metadata_header( + self._task.client, self._task.uid, stream_type + ) if metadata_header is None: raise ValueError( f"Task {self._task.uid} does not have a {stream_type.value} stream" ) return Stream( - _TaskContext(self._task.client, self._task.uid, stream_type, - metadata_header), + _TaskContext( + self._task.client, self._task.uid, stream_type, metadata_header + ), _MultiGCSFileReader(), converter, ) @@ -923,4 +976,3 @@ def get_stream( def get_task(client, task_id): """Returns the task with the given id.""" return ExportTask(Task.get_task(client, task_id)) - \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/schema/foundry/app.py b/libs/labelbox/src/labelbox/schema/foundry/app.py index f73d5056f..2886dec15 100644 --- a/libs/labelbox/src/labelbox/schema/foundry/app.py +++ b/libs/labelbox/src/labelbox/schema/foundry/app.py @@ -13,7 +13,7 @@ class App(_CamelCaseMixin): class_to_schema_id: Dict[str, str] ontology_id: str created_by: Optional[str] = None - + model_config = ConfigDict(protected_namespaces=()) @classmethod @@ -21,4 +21,4 @@ def type_name(cls): return "App" -APP_FIELD_NAMES = list(App.model_json_schema()['properties'].keys()) +APP_FIELD_NAMES = list(App.model_json_schema()["properties"].keys()) diff --git a/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py b/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py index 27d577bc0..914a363c7 100644 --- a/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py +++ b/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py @@ -6,7 +6,6 @@ class FoundryClient: - def __init__(self, client): self.client = client @@ -35,7 +34,7 @@ def _create_app(self, app: App) -> App: try: response = self.client.execute(query_str, params) except exceptions.LabelboxError as e: - raise exceptions.LabelboxError('Unable to create app', e) + raise exceptions.LabelboxError("Unable to create app", e) return App(**response["createModelFoundryApp"]) def _get_app(self, id: str) -> App: @@ -55,7 +54,7 @@ def _get_app(self, id: str) -> App: except 
exceptions.InvalidQueryError as e: raise exceptions.ResourceNotFoundError(App, params) except Exception as e: - raise exceptions.LabelboxError(f'Unable to get app with id {id}', e) + raise exceptions.LabelboxError(f"Unable to get app with id {id}", e) return App(**response["findModelFoundryApp"]) def _delete_app(self, id: str) -> None: @@ -70,11 +69,16 @@ def _delete_app(self, id: str) -> None: try: self.client.execute(query_str, params) except Exception as e: - raise exceptions.LabelboxError(f'Unable to delete app with id {id}', - e) + raise exceptions.LabelboxError( + f"Unable to delete app with id {id}", e + ) - def run_app(self, model_run_name: str, - data_rows: Union[DataRowIds, GlobalKeys], app_id: str) -> Task: + def run_app( + self, + model_run_name: str, + data_rows: Union[DataRowIds, GlobalKeys], + app_id: str, + ) -> Task: app = self._get_app(app_id) params = { @@ -82,10 +86,14 @@ def run_app(self, model_run_name: str, "name": model_run_name, "classToSchemaId": app.class_to_schema_id, "inferenceParams": app.inference_params, - "ontologyId": app.ontology_id + "ontologyId": app.ontology_id, } - data_rows_key = "dataRowIds" if data_rows.id_type == IdType.DataRowId else "globalKeys" + data_rows_key = ( + "dataRowIds" + if data_rows.id_type == IdType.DataRowId + else "globalKeys" + ) params[data_rows_key] = list(data_rows) query = """ @@ -99,6 +107,6 @@ def run_app(self, model_run_name: str, try: response = self.client.execute(query, {"input": params}) except Exception as e: - raise exceptions.LabelboxError('Unable to run foundry app', e) + raise exceptions.LabelboxError("Unable to run foundry app", e) task_id = response["createModelJobForDataRows"]["taskId"] return Task.get_task(self.client, task_id) diff --git a/libs/labelbox/src/labelbox/schema/foundry/model.py b/libs/labelbox/src/labelbox/schema/foundry/model.py index 87fda22f2..6c2ab6d88 100644 --- a/libs/labelbox/src/labelbox/schema/foundry/model.py +++ b/libs/labelbox/src/labelbox/schema/foundry/model.py @@ -15,4 +15,4 @@ class Model(_CamelCaseMixin, BaseModel): created_at: datetime -MODEL_FIELD_NAMES = list(Model.model_json_schema()['properties'].keys()) +MODEL_FIELD_NAMES = list(Model.model_json_schema()["properties"].keys()) diff --git a/libs/labelbox/src/labelbox/schema/iam_integration.py b/libs/labelbox/src/labelbox/schema/iam_integration.py index 00c4f0ae9..cb5309929 100644 --- a/libs/labelbox/src/labelbox/schema/iam_integration.py +++ b/libs/labelbox/src/labelbox/schema/iam_integration.py @@ -17,7 +17,7 @@ class GcpIamIntegrationSettings: class IAMIntegration(DbObject): - """ Represents an IAM integration for delegated access + """Represents an IAM integration for delegated access Attributes: name (str) @@ -31,9 +31,9 @@ class IAMIntegration(DbObject): """ def __init__(self, client, data): - settings = data.pop('settings', None) + settings = data.pop("settings", None) if settings is not None: - type_name = settings.pop('__typename') + type_name = settings.pop("__typename") settings = {snake_case(k): v for k, v in settings.items()} if type_name == "GcpIamIntegrationSettings": self.settings = GcpIamIntegrationSettings(**settings) diff --git a/libs/labelbox/src/labelbox/schema/id_type.py b/libs/labelbox/src/labelbox/schema/id_type.py index a78dc572c..3ecad4ca1 100644 --- a/libs/labelbox/src/labelbox/schema/id_type.py +++ b/libs/labelbox/src/labelbox/schema/id_type.py @@ -15,10 +15,11 @@ class BaseStrEnum(str, Enum): class IdType(BaseStrEnum): """ The type of id used to identify a data row. 
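A sketch of driving the foundry run_app call reformatted above; the app id and global keys are placeholders:

from labelbox import Client
from labelbox.schema.identifiables import GlobalKeys
from labelbox.schema.foundry.foundry_client import FoundryClient

client = Client(api_key="<API_KEY>")                   # placeholder credentials
foundry = FoundryClient(client)

task = foundry.run_app(
    model_run_name="foundry-predictions",
    data_rows=GlobalKeys(["image-001", "image-002"]),  # or DataRowIds([...])
    app_id="<APP_ID>",                                  # placeholder app id
)
task.wait_till_done()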
- + Currently supported types are: - DataRowId: The id assigned to a data row by Labelbox. - GlobalKey: The id assigned to a data row by the user. """ + DataRowId = "ID" GlobalKey = "GKEY" diff --git a/libs/labelbox/src/labelbox/schema/identifiables.py b/libs/labelbox/src/labelbox/schema/identifiables.py index 73a6c4bb3..590ac70c9 100644 --- a/libs/labelbox/src/labelbox/schema/identifiables.py +++ b/libs/labelbox/src/labelbox/schema/identifiables.py @@ -4,7 +4,6 @@ class Identifiables: - def __init__(self, iterable, id_type: str): """ Args: @@ -36,7 +35,10 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: if not isinstance(other, Identifiables): return False - return self._iterable == other._iterable and self._id_type == other._id_type + return ( + self._iterable == other._iterable + and self._id_type == other._id_type + ) class UniqueIds(Identifiables): diff --git a/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py b/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py index 62962d70d..817a02561 100644 --- a/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py +++ b/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py @@ -2,8 +2,14 @@ from typing import List -from labelbox.schema.internal.data_row_upsert_item import DataRowItemBase, DataRowUpsertItem, DataRowCreateItem -from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator +from labelbox.schema.internal.data_row_upsert_item import ( + DataRowItemBase, + DataRowUpsertItem, + DataRowCreateItem, +) +from labelbox.schema.internal.descriptor_file_creator import ( + DescriptorFileCreator, +) from pydantic import BaseModel @@ -16,22 +22,27 @@ class UploadManifest(BaseModel): SOURCE_SDK = "SDK" -def upload_in_chunks(client, specs: List[DataRowItemBase], - file_upload_thread_count: int, - max_chunk_size_bytes: int) -> UploadManifest: +def upload_in_chunks( + client, + specs: List[DataRowItemBase], + file_upload_thread_count: int, + max_chunk_size_bytes: int, +) -> UploadManifest: empty_specs = list(filter(lambda spec: spec.is_empty(), specs)) if empty_specs: ids = list(map(lambda spec: spec.id.get("value"), empty_specs)) ids = list(filter(lambda x: x is not None and len(x) > 0, ids)) if len(ids) > 0: raise ValueError( - f"The following items have an empty payload: {ids}") + f"The following items have an empty payload: {ids}" + ) else: # case of create items raise ValueError("Some items have an empty payload") chunk_uris = DescriptorFileCreator(client).create( - specs, max_chunk_size_bytes=max_chunk_size_bytes) + specs, max_chunk_size_bytes=max_chunk_size_bytes + ) - return UploadManifest(source=SOURCE_SDK, - item_count=len(specs), - chunk_uris=chunk_uris) + return UploadManifest( + source=SOURCE_SDK, item_count=len(specs), chunk_uris=chunk_uris + ) diff --git a/libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py b/libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py index 5759ca818..cc9bbb2c3 100644 --- a/libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py +++ b/libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py @@ -16,30 +16,30 @@ class DataRowItemBase(ABC, BaseModel): payload: dict @abstractmethod - def is_empty(self) -> bool: - ... + def is_empty(self) -> bool: ... 
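The identifier wrappers touched above distinguish Labelbox-assigned ids from user-assigned global keys; a small sketch (the ids themselves are placeholders):

from labelbox.schema.id_type import IdType
from labelbox.schema.identifiables import DataRowIds, GlobalKeys

by_id = DataRowIds(["clxyz0000000000000000000", "clxyz0000000000000000001"])
by_key = GlobalKeys(["image-001", "image-002"])

assert by_id.id_type == IdType.DataRowId   # serialized as "ID"
assert by_key.id_type == IdType.GlobalKey  # serialized as "GKEY"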
@classmethod def build( cls, dataset_id: str, items: List[dict], - key_types: Optional[Tuple[type, ...]] = () + key_types: Optional[Tuple[type, ...]] = (), ) -> List["DataRowItemBase"]: upload_items = [] for item in items: # enforce current dataset's id for all specs - item['dataset_id'] = dataset_id - key = item.pop('key', None) + item["dataset_id"] = dataset_id + key = item.pop("key", None) if not key: - key = {'type': 'AUTO', 'value': ''} + key = {"type": "AUTO", "value": ""} elif isinstance(key, key_types): # type: ignore - key = {'type': key.id_type.value, 'value': key.key} + key = {"type": key.id_type.value, "value": key.key} else: if not key_types: raise ValueError( - f"Can not have a key for this item, got: {key}") + f"Can not have a key for this item, got: {key}" + ) raise ValueError( f"Key must be an instance of {', '.join([t.__name__ for t in key_types])}, got: {type(item['key']).__name__}" ) @@ -51,27 +51,28 @@ def build( class DataRowUpsertItem(DataRowItemBase): - def is_empty(self) -> bool: """ The payload is considered empty if it's actually empty or the only key is `dataset_id`. :return: bool """ - return (not self.payload or - len(self.payload.keys()) == 1 and "dataset_id" in self.payload) + return ( + not self.payload + or len(self.payload.keys()) == 1 + and "dataset_id" in self.payload + ) @classmethod def build( cls, dataset_id: str, items: List[dict], - key_types: Optional[Tuple[type, ...]] = (UniqueId, GlobalKey) + key_types: Optional[Tuple[type, ...]] = (UniqueId, GlobalKey), ) -> List["DataRowItemBase"]: return super().build(dataset_id, items, (UniqueId, GlobalKey)) class DataRowCreateItem(DataRowItemBase): - def is_empty(self) -> bool: """ The payload is considered empty if it's actually empty or row_data is empty @@ -79,22 +80,28 @@ def is_empty(self) -> bool: :return: bool """ row_data = self.payload.get("row_data", None) or self.payload.get( - DataRow.row_data, None) + DataRow.row_data, None + ) - return (not self._is_legacy_conversational_data() and - (not self.payload or len(self.payload.keys()) == 1 and - "dataset_id" in self.payload or row_data is None or - len(row_data) == 0)) + return not self._is_legacy_conversational_data() and ( + not self.payload + or len(self.payload.keys()) == 1 + and "dataset_id" in self.payload + or row_data is None + or len(row_data) == 0 + ) def _is_legacy_conversational_data(self) -> bool: - return "conversationalData" in self.payload.keys( - ) or "conversational_data" in self.payload.keys() + return ( + "conversationalData" in self.payload.keys() + or "conversational_data" in self.payload.keys() + ) @classmethod def build( cls, dataset_id: str, items: List[dict], - key_types: Optional[Tuple[type, ...]] = () + key_types: Optional[Tuple[type, ...]] = (), ) -> List["DataRowItemBase"]: return super().build(dataset_id, items, ()) diff --git a/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py b/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py index 07128fdd1..ce3ce4b35 100644 --- a/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py +++ b/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py @@ -12,10 +12,15 @@ from labelbox.orm.model import Field from labelbox.schema.embedding import EmbeddingVector from labelbox.schema.internal.datarow_upload_constants import ( - FILE_UPLOAD_THREAD_COUNT) -from labelbox.schema.internal.data_row_upsert_item import DataRowItemBase, DataRowUpsertItem + FILE_UPLOAD_THREAD_COUNT, +) +from 
labelbox.schema.internal.data_row_upsert_item import ( + DataRowItemBase, + DataRowUpsertItem, +) from typing import TYPE_CHECKING + if TYPE_CHECKING: from labelbox import Client @@ -40,19 +45,25 @@ def create(self, items, max_chunk_size_bytes=None) -> List[str]: json_chunks = self._chunk_down_by_bytes(items, max_chunk_size_bytes) with ThreadPoolExecutor(FILE_UPLOAD_THREAD_COUNT) as executor: futures = [ - executor.submit(self.client.upload_data, chunk, - "application/json", "json_import.json") + executor.submit( + self.client.upload_data, + chunk, + "application/json", + "json_import.json", + ) for chunk in json_chunks ] return [future.result() for future in as_completed(futures)] def create_one(self, items) -> List[str]: - items = self._prepare_items_for_upload(items,) + items = self._prepare_items_for_upload( + items, + ) # Prepare and upload the descriptor file data = json.dumps(items) - return self.client.upload_data(data, - content_type="application/json", - filename="json_import.json") + return self.client.upload_data( + data, content_type="application/json", filename="json_import.json" + ) def _prepare_items_for_upload(self, items, is_upsert=False): """ @@ -99,20 +110,20 @@ def _prepare_items_for_upload(self, items, is_upsert=False): AssetAttachment = Entity.AssetAttachment def upload_if_necessary(item): - if is_upsert and 'row_data' not in item: + if is_upsert and "row_data" not in item: # When upserting, row_data is not required return item - row_data = item['row_data'] + row_data = item["row_data"] if isinstance(row_data, str) and os.path.exists(row_data): item_url = self.client.upload_file(row_data) - item['row_data'] = item_url - if 'external_id' not in item: + item["row_data"] = item_url + if "external_id" not in item: # Default `external_id` to local file name - item['external_id'] = row_data + item["external_id"] = row_data return item def validate_attachments(item): - attachments = item.get('attachments') + attachments = item.get("attachments") if attachments: if isinstance(attachments, list): for attachment in attachments: @@ -139,18 +150,25 @@ def validate_conversational_data(conversational_data: list) -> None: """ def check_message_keys(message): - accepted_message_keys = set([ - "messageId", "timestampUsec", "content", "user", "align", - "canLabel" - ]) + accepted_message_keys = set( + [ + "messageId", + "timestampUsec", + "content", + "user", + "align", + "canLabel", + ] + ) for key in message.keys(): if not key in accepted_message_keys: raise KeyError( f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}" ) - if conversational_data and not isinstance(conversational_data, - list): + if conversational_data and not isinstance( + conversational_data, list + ): raise ValueError( f"conversationalData must be a list. 
Found {type(conversational_data)}" ) @@ -158,11 +176,12 @@ def check_message_keys(message): [check_message_keys(message) for message in conversational_data] def parse_metadata_fields(item): - metadata_fields = item.get('metadata_fields') + metadata_fields = item.get("metadata_fields") if metadata_fields: mdo = self.client.get_data_row_metadata_ontology() - item['metadata_fields'] = mdo.parse_upsert_metadata( - metadata_fields) + item["metadata_fields"] = mdo.parse_upsert_metadata( + metadata_fields + ) def format_row(item): # Formats user input into a consistent dict structure @@ -182,19 +201,28 @@ def format_row(item): return item def validate_keys(item): - if not is_upsert and 'row_data' not in item: + if not is_upsert and "row_data" not in item: raise InvalidQueryError( - "`row_data` missing when creating DataRow.") + "`row_data` missing when creating DataRow." + ) - if isinstance(item.get('row_data'), - str) and item.get('row_data').startswith("s3:/"): + if isinstance(item.get("row_data"), str) and item.get( + "row_data" + ).startswith("s3:/"): raise InvalidQueryError( - "row_data: s3 assets must start with 'https'.") + "row_data: s3 assets must start with 'https'." + ) allowed_extra_fields = { - 'attachments', 'media_type', 'dataset_id', 'embeddings' + "attachments", + "media_type", + "dataset_id", + "embeddings", } - invalid_keys = set(item) - {f.name for f in DataRow.fields() - } - allowed_extra_fields + invalid_keys = ( + set(item) + - {f.name for f in DataRow.fields()} + - allowed_extra_fields + ) if invalid_keys: raise InvalidAttributeError(DataRow, invalid_keys) return item @@ -210,12 +238,11 @@ def format_legacy_conversational_data(item): global_key = item.pop("globalKey") item["globalKey"] = global_key validate_conversational_data(messages) - one_conversation = \ - { - "type": type, - "version": version, - "messages": messages - } + one_conversation = { + "type": type, + "version": version, + "messages": messages, + } item["row_data"] = one_conversation return item @@ -246,7 +273,7 @@ def convert_item(data_row_item): item = upload_if_necessary(item) if isinstance(data_row_item, DataRowItemBase): - return {'id': data_row_item.id, 'payload': item} + return {"id": data_row_item.id, "payload": item} else: return item @@ -261,8 +288,9 @@ def convert_item(data_row_item): return items - def _chunk_down_by_bytes(self, items: List[dict], - max_chunk_size: int) -> Generator[str, None, None]: + def _chunk_down_by_bytes( + self, items: List[dict], max_chunk_size: int + ) -> Generator[str, None, None]: """ Recursively chunks down a list of items into smaller lists until each list is less than or equal to max_chunk_size bytes NOTE: if one data row is larger than max_chunk_size, it will be returned as one chunk diff --git a/libs/labelbox/src/labelbox/schema/invite.py b/libs/labelbox/src/labelbox/schema/invite.py index 266e14c7f..c89a8b08c 100644 --- a/libs/labelbox/src/labelbox/schema/invite.py +++ b/libs/labelbox/src/labelbox/schema/invite.py @@ -22,6 +22,7 @@ class Invite(DbObject): """ An object representing a user invite """ + created_at = Field.DateTime("created_at") organization_role_name = Field.String("organization_role_name") email = Field.String("email", "inviteeEmail") @@ -31,7 +32,9 @@ def __init__(self, client, invite_response): super().__init__(client, invite_response) self.project_roles = [ - ProjectRole(project=client.get_project(r['projectId']), - role=client.get_roles()[format_role( - r['projectRoleName'])]) for r in project_roles + ProjectRole( + 
project=client.get_project(r["projectId"]), + role=client.get_roles()[format_role(r["projectRoleName"])], + ) + for r in project_roles ] diff --git a/libs/labelbox/src/labelbox/schema/label.py b/libs/labelbox/src/labelbox/schema/label.py index 7a7d2dc51..371193a13 100644 --- a/libs/labelbox/src/labelbox/schema/label.py +++ b/libs/labelbox/src/labelbox/schema/label.py @@ -10,7 +10,7 @@ class Label(DbObject, Updateable, BulkDeletable): - """ Label represents an assessment on a DataRow. For example one label could + """Label represents an assessment on a DataRow. For example one label could contain 100 bounding boxes (annotations). Attributes: @@ -46,7 +46,7 @@ def __init__(self, *args, **kwargs): @staticmethod def bulk_delete(labels) -> None: - """ Deletes all the given Labels. + """Deletes all the given Labels. Args: labels (list of Label): The Labels to delete. @@ -54,7 +54,7 @@ def bulk_delete(labels) -> None: BulkDeletable._bulk_delete(labels, False) def create_review(self, **kwargs) -> "Review": - """ Creates a Review for this label. + """Creates a Review for this label. Args: **kwargs: Review attributes. At a minimum, a `Review.score` field value must be provided. @@ -64,7 +64,7 @@ def create_review(self, **kwargs) -> "Review": return self.client._create(Entity.Review, kwargs) def create_benchmark(self) -> "Benchmark": - """ Creates a Benchmark for this Label. + """Creates a Benchmark for this Label. Returns: The newly created Benchmark. @@ -72,7 +72,9 @@ def create_benchmark(self) -> "Benchmark": label_id_param = "labelId" query_str = """mutation CreateBenchmarkPyApi($%s: ID!) { createBenchmark(data: {labelId: $%s}) {%s}} """ % ( - label_id_param, label_id_param, - query.results_query_part(Entity.Benchmark)) + label_id_param, + label_id_param, + query.results_query_part(Entity.Benchmark), + ) res = self.client.execute(query_str, {label_id_param: self.uid}) return Entity.Benchmark(self.client, res["createBenchmark"]) diff --git a/libs/labelbox/src/labelbox/schema/labeling_frontend.py b/libs/labelbox/src/labelbox/schema/labeling_frontend.py index 147148ece..49bc8825f 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_frontend.py +++ b/libs/labelbox/src/labelbox/schema/labeling_frontend.py @@ -3,7 +3,7 @@ class LabelingFrontend(DbObject): - """ Label editor. + """Label editor. Represents an HTML / JavaScript UI that is used to generate labels. “Editor” is the default Labeling Frontend that comes in every @@ -16,13 +16,14 @@ class LabelingFrontend(DbObject): projects (Relationship): `ToMany` relationship to Project """ + name = Field.String("name") description = Field.String("description") iframe_url_path = Field.String("iframe_url_path") class LabelingFrontendOptions(DbObject): - """ Label interface options. + """Label interface options. Attributes: customization_options (str) @@ -31,6 +32,7 @@ class LabelingFrontendOptions(DbObject): labeling_frontend (Relationship): `ToOne` relationship to LabelingFrontend organization (Relationship): `ToOne` relationship to Organization """ + customization_options = Field.String("customization_options") project = Relationship.ToOne("Project") diff --git a/libs/labelbox/src/labelbox/schema/labeling_service.py b/libs/labelbox/src/labelbox/schema/labeling_service.py index 70376f2e8..a7a1845be 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service.py @@ -16,6 +16,7 @@ class LabelingService(_CamelCaseMixin): """ Labeling service for a project. 
This is a service that can be requested to label data for a project. """ + id: Cuid project_id: Cuid created_at: datetime @@ -28,10 +29,11 @@ def __init__(self, **kwargs): super().__init__(**kwargs) if not self.client.enable_experimental: raise RuntimeError( - "Please enable experimental in client to use LabelingService") + "Please enable experimental in client to use LabelingService" + ) @classmethod - def start(cls, client, project_id: Cuid) -> 'LabelingService': + def start(cls, client, project_id: Cuid) -> "LabelingService": """ Starts the labeling service for the project. This is equivalent to a UI action to Request Specialized Labelers @@ -52,7 +54,7 @@ def start(cls, client, project_id: Cuid) -> 'LabelingService': return cls.get(client, project_id) @classmethod - def get(cls, client, project_id: Cuid) -> 'LabelingService': + def get(cls, client, project_id: Cuid) -> "LabelingService": """ Returns the labeling service associated with the project. @@ -74,14 +76,15 @@ def get(cls, client, project_id: Cuid) -> 'LabelingService': result = client.execute(query, {"projectId": project_id}) if result["projectBoostWorkforce"] is None: raise ResourceNotFoundError( - message="The project does not have a labeling service.") + message="The project does not have a labeling service." + ) data = result["projectBoostWorkforce"] data["client"] = client return LabelingService(**data) - def request(self) -> 'LabelingService': + def request(self) -> "LabelingService": """ - Creates a request to labeling service to start labeling for the project. + Creates a request to labeling service to start labeling for the project. Our back end will validate that the project is ready for labeling and then request the labeling service. Returns: @@ -100,15 +103,18 @@ def request(self) -> 'LabelingService': } } """ - result = self.client.execute(query_str, {"projectId": self.project_id}, - raise_return_resource_not_found=True) + result = self.client.execute( + query_str, + {"projectId": self.project_id}, + raise_return_resource_not_found=True, + ) success = result["validateAndRequestProjectBoostWorkforce"]["success"] if not success: raise Exception("Failed to start labeling service") return LabelingService.get(self.client, self.project_id) @classmethod - def getOrCreate(cls, client, project_id: Cuid) -> 'LabelingService': + def getOrCreate(cls, client, project_id: Cuid) -> "LabelingService": """ Returns the labeling service associated with the project. If the project does not have a labeling service, it will create one. @@ -127,4 +133,4 @@ def dashboard(self) -> LabelingServiceDashboard: Raises: ResourceNotFoundError: If the project does not have a labeling service. 
""" - return LabelingServiceDashboard.get(self.client, self.project_id) \ No newline at end of file + return LabelingServiceDashboard.get(self.client, self.project_id) diff --git a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py index 41ce1f4d1..10a956a66 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py @@ -64,6 +64,7 @@ class LabelingServiceDashboard(_CamelCaseMixin): editor_task_type (EditorTaskType): editor task type of the project client (Any): labelbox client """ + id: str = Field(frozen=True) name: str = Field(frozen=True) created_at: Optional[datetime] = Field(frozen=True, default=None) @@ -83,7 +84,8 @@ def __init__(self, **kwargs): super().__init__(**kwargs) if not self.client.enable_experimental: raise RuntimeError( - "Please enable experimental in client to use LabelingService") + "Please enable experimental in client to use LabelingService" + ) @property def service_type(self): @@ -96,22 +98,34 @@ def service_type(self): if self.editor_task_type is None: return sentence_case(self.media_type.value) - if self.editor_task_type == EditorTaskType.OfflineModelChatEvaluation and self.media_type == MediaType.Conversational: + if ( + self.editor_task_type == EditorTaskType.OfflineModelChatEvaluation + and self.media_type == MediaType.Conversational + ): return "Offline chat evaluation" - if self.editor_task_type == EditorTaskType.ModelChatEvaluation and self.media_type == MediaType.Conversational: + if ( + self.editor_task_type == EditorTaskType.ModelChatEvaluation + and self.media_type == MediaType.Conversational + ): return "Live chat evaluation" - if self.editor_task_type == EditorTaskType.ResponseCreation and self.media_type == MediaType.Text: + if ( + self.editor_task_type == EditorTaskType.ResponseCreation + and self.media_type == MediaType.Text + ): return "Response creation" - if self.media_type == MediaType.LLMPromptCreation or self.media_type == MediaType.LLMPromptResponseCreation: + if ( + self.media_type == MediaType.LLMPromptCreation + or self.media_type == MediaType.LLMPromptResponseCreation + ): return "Prompt response creation" return sentence_case(self.media_type.value) @classmethod - def get(cls, client, project_id: str) -> 'LabelingServiceDashboard': + def get(cls, client, project_id: str) -> "LabelingServiceDashboard": """ Returns the labeling service associated with the project. 
@@ -140,7 +154,6 @@ def get_all( client, search_query: Optional[List[SearchFilter]] = None, ) -> PaginatedCollection: - if search_query is not None: template = Template( """query SearchProjectsPyApi($$first: Int, $$from: String) { @@ -150,7 +163,8 @@ def get_all( pageInfo { endCursor } } } - """) + """ + ) else: template = Template( """query SearchProjectsPyApi($$first: Int, $$from: String) { @@ -160,46 +174,48 @@ def get_all( pageInfo { endCursor } } } - """) + """ + ) query_str = template.substitute( labeling_dashboard_selections=GRAPHQL_QUERY_SELECTIONS, search_query=build_search_filter(search_query) - if search_query else None, + if search_query + else None, ) params: Dict[str, Union[str, int]] = {} def convert_to_labeling_service_dashboard(client, data): - data['client'] = client + data["client"] = client return LabelingServiceDashboard(**data) return PaginatedCollection( client=client, query=query_str, params=params, - dereferencing=['searchProjects', 'nodes'], + dereferencing=["searchProjects", "nodes"], obj_class=convert_to_labeling_service_dashboard, - cursor_path=['searchProjects', 'pageInfo', 'endCursor'], + cursor_path=["searchProjects", "pageInfo", "endCursor"], experimental=True, ) @root_validator(pre=True) def convert_boost_data(cls, data): - if 'boostStatus' in data: - data['status'] = LabelingServiceStatus(data.pop('boostStatus')) + if "boostStatus" in data: + data["status"] = LabelingServiceStatus(data.pop("boostStatus")) - if 'boostRequestedAt' in data: - data['created_at'] = data.pop('boostRequestedAt') + if "boostRequestedAt" in data: + data["created_at"] = data.pop("boostRequestedAt") - if 'boostUpdatedAt' in data: - data['updated_at'] = data.pop('boostUpdatedAt') + if "boostUpdatedAt" in data: + data["updated_at"] = data.pop("boostUpdatedAt") - if 'boostRequestedBy' in data: - data['created_by_id'] = data.pop('boostRequestedBy') + if "boostRequestedBy" in data: + data["created_by_id"] = data.pop("boostRequestedBy") return data def dict(self, *args, **kwargs): row = super().dict(*args, **kwargs) - row.pop('client') - row['service_type'] = self.service_type + row.pop("client") + row["service_type"] = self.service_type return row diff --git a/libs/labelbox/src/labelbox/schema/labeling_service_status.py b/libs/labelbox/src/labelbox/schema/labeling_service_status.py index 62cfd938e..c15cf73b9 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service_status.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service_status.py @@ -2,12 +2,12 @@ class LabelingServiceStatus(Enum): - Accepted = 'ACCEPTED' - Calibration = 'CALIBRATION' - Complete = 'COMPLETE' - Production = 'PRODUCTION' - Requested = 'REQUESTED' - SetUp = 'SET_UP' + Accepted = "ACCEPTED" + Calibration = "CALIBRATION" + Complete = "COMPLETE" + Production = "PRODUCTION" + Requested = "REQUESTED" + SetUp = "SET_UP" Missing = None @classmethod @@ -15,10 +15,10 @@ def is_supported(cls, value): return isinstance(value, cls) @classmethod - def _missing_(cls, value) -> 'LabelingServiceStatus': + def _missing_(cls, value) -> "LabelingServiceStatus": """Handle missing null new task types - Handle upper case names for compatibility with - the GraphQL""" + Handle upper case names for compatibility with + the GraphQL""" if value is None: return cls.Missing diff --git a/libs/labelbox/src/labelbox/schema/media_type.py b/libs/labelbox/src/labelbox/schema/media_type.py index 99807522b..ae0bbbb3f 100644 --- a/libs/labelbox/src/labelbox/schema/media_type.py +++ b/libs/labelbox/src/labelbox/schema/media_type.py @@ -27,9 
+27,9 @@ class MediaType(Enum): @classmethod def _missing_(cls, value): """Handle missing null data types for projects - created without setting allowedMediaType - Handle upper case names for compatibility with - the GraphQL""" + created without setting allowedMediaType + Handle upper case names for compatibility with + the GraphQL""" if value is None: return cls.Unknown @@ -46,9 +46,11 @@ def matches(value, name): value_underscore = value.replace("-", "_") camel_case_value = camel_case(value_underscore) - return (value_upper == name_upper or - value_underscore.upper() == name_upper or - camel_case_value.upper() == name_upper) + return ( + value_upper == name_upper + or value_underscore.upper() == name_upper + or camel_case_value.upper() == name_upper + ) for name, member in cls.__members__.items(): if matches(value, name): @@ -58,18 +60,23 @@ def matches(value, name): @classmethod def is_supported(cls, value): - return isinstance(value, - cls) and value not in [cls.Unknown, cls.Unsupported] + return isinstance(value, cls) and value not in [ + cls.Unknown, + cls.Unsupported, + ] @classmethod def get_supported_members(cls): return [ - item for item in cls.__members__ + item + for item in cls.__members__ if item not in ["Unknown", "Unsupported"] ] def get_media_type_validation_error(media_type): - return TypeError(f"{media_type} is not a valid media type. Use" - f" any of {MediaType.get_supported_members()}" - " from MediaType. Example: MediaType.Image.") + return TypeError( + f"{media_type} is not a valid media type. Use" + f" any of {MediaType.get_supported_members()}" + " from MediaType. Example: MediaType.Image." + ) diff --git a/libs/labelbox/src/labelbox/schema/model.py b/libs/labelbox/src/labelbox/schema/model.py index 692f43fad..e78620002 100644 --- a/libs/labelbox/src/labelbox/schema/model.py +++ b/libs/labelbox/src/labelbox/schema/model.py @@ -9,18 +9,18 @@ class Model(DbObject): """A model represents a program that has been trained and - can make predictions on new data. - Attributes: - name (str) - model_runs (Relationship): `ToMany` relationship to ModelRun - """ + can make predictions on new data. + Attributes: + name (str) + model_runs (Relationship): `ToMany` relationship to ModelRun + """ name = Field.String("name") ontology_id = Field.String("ontology_id") model_runs = Relationship.ToMany("ModelRun", False) def create_model_run(self, name, config=None) -> "ModelRun": - """ Creates a model run belonging to this model. + """Creates a model run belonging to this model. Args: name (string): The name for the model run. @@ -34,17 +34,22 @@ def create_model_run(self, name, config=None) -> "ModelRun": ModelRun = Entity.ModelRun query_str = """mutation CreateModelRunPyApi($%s: String!, $%s: Json, $%s: ID!) { createModelRun(data: {name: $%s, trainingMetadata: $%s, modelId: $%s}) {%s}}""" % ( - name_param, config_param, model_id_param, name_param, config_param, - model_id_param, query.results_query_part(ModelRun)) - res = self.client.execute(query_str, { - name_param: name, - config_param: config, - model_id_param: self.uid - }) + name_param, + config_param, + model_id_param, + name_param, + config_param, + model_id_param, + query.results_query_part(ModelRun), + ) + res = self.client.execute( + query_str, + {name_param: name, config_param: config, model_id_param: self.uid}, + ) return ModelRun(self.client, res["createModelRun"]) def delete(self) -> None: - """ Deletes specified model. + """Deletes specified model. Returns: Query execution success. 
diff --git a/libs/labelbox/src/labelbox/schema/model_config.py b/libs/labelbox/src/labelbox/schema/model_config.py index 46c0deca9..369315cd0 100644 --- a/libs/labelbox/src/labelbox/schema/model_config.py +++ b/libs/labelbox/src/labelbox/schema/model_config.py @@ -3,9 +3,9 @@ class ModelConfig(DbObject): - """ A ModelConfig represents a set of inference params configured for a model + """A ModelConfig represents a set of inference params configured for a model - Attributes: + Attributes: inference_params (JSON): Dict of inference params model_id (str): ID of the model to configure name (str): Name of config diff --git a/libs/labelbox/src/labelbox/schema/model_run.py b/libs/labelbox/src/labelbox/schema/model_run.py index 7f8714008..73c013b57 100644 --- a/libs/labelbox/src/labelbox/schema/model_run.py +++ b/libs/labelbox/src/labelbox/schema/model_run.py @@ -5,7 +5,16 @@ import warnings from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, Dict, Iterable, Union, Tuple, List, Optional, Any +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + Union, + Tuple, + List, + Optional, + Any, +) import requests @@ -14,12 +23,17 @@ from labelbox.orm.model import Field, Relationship, Entity from labelbox.orm.query import results_query_part from labelbox.pagination import PaginatedCollection -from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy +from labelbox.schema.conflict_resolution_strategy import ( + ConflictResolutionStrategy, +) from labelbox.schema.export_params import ModelRunExportParams from labelbox.schema.export_task import ExportTask from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds -from labelbox.schema.send_to_annotate_params import SendToAnnotateFromModelParams, build_destination_task_queue_input, \ - build_predictions_input +from labelbox.schema.send_to_annotate_params import ( + SendToAnnotateFromModelParams, + build_destination_task_queue_input, + build_predictions_input, +) from labelbox.schema.task import Task if TYPE_CHECKING: @@ -53,10 +67,12 @@ class Status(Enum): COMPLETE = "COMPLETE" FAILED = "FAILED" - def upsert_labels(self, - label_ids: Optional[List[str]] = None, - project_id: Optional[str] = None, - timeout_seconds=3600): + def upsert_labels( + self, + label_ids: Optional[List[str]] = None, + project_id: Optional[str] = None, + timeout_seconds=3600, + ): """ Adds data rows and labels to a Model Run @@ -75,7 +91,8 @@ def upsert_labels(self, if not use_label_ids and not use_project_id: raise ValueError( - "Must provide at least one label id or a project id") + "Must provide at least one label id or a project id" + ) if use_label_ids and use_project_id: raise ValueError("Must only one of label ids, project id") @@ -83,60 +100,64 @@ def upsert_labels(self, if use_label_ids: return self._upsert_labels_by_label_ids(label_ids, timeout_seconds) else: # use_project_id - return self._upsert_labels_by_project_id(project_id, - timeout_seconds) + return self._upsert_labels_by_project_id( + project_id, timeout_seconds + ) - def _upsert_labels_by_label_ids(self, label_ids: List[str], - timeout_seconds: int): - mutation_name = 'createMEAModelRunLabelRegistrationTask' + def _upsert_labels_by_label_ids( + self, label_ids: List[str], timeout_seconds: int + ): + mutation_name = "createMEAModelRunLabelRegistrationTask" create_task_query_str = """mutation createMEAModelRunLabelRegistrationTaskPyApi($modelRunId: ID!, $labelIds : [ID!]!) 
{ %s(where : { id : $modelRunId}, data : {labelIds: $labelIds})} """ % (mutation_name) - res = self.client.execute(create_task_query_str, { - 'modelRunId': self.uid, - 'labelIds': label_ids - }) + res = self.client.execute( + create_task_query_str, + {"modelRunId": self.uid, "labelIds": label_ids}, + ) task_id = res[mutation_name] status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){ MEALabelRegistrationTaskStatus(where: $where) {status errorMessage} } """ - return self._wait_until_done(lambda: self.client.execute( - status_query_str, {'where': { - 'id': task_id - }})['MEALabelRegistrationTaskStatus'], - timeout_seconds=timeout_seconds) - - def _upsert_labels_by_project_id(self, project_id: str, - timeout_seconds: int): - mutation_name = 'createMEAModelRunProjectLabelRegistrationTask' + return self._wait_until_done( + lambda: self.client.execute( + status_query_str, {"where": {"id": task_id}} + )["MEALabelRegistrationTaskStatus"], + timeout_seconds=timeout_seconds, + ) + + def _upsert_labels_by_project_id( + self, project_id: str, timeout_seconds: int + ): + mutation_name = "createMEAModelRunProjectLabelRegistrationTask" create_task_query_str = """mutation createMEAModelRunProjectLabelRegistrationTaskPyApi($modelRunId: ID!, $projectId : ID!) { %s(where : { modelRunId : $modelRunId, projectId: $projectId})} """ % (mutation_name) - res = self.client.execute(create_task_query_str, { - 'modelRunId': self.uid, - 'projectId': project_id - }) + res = self.client.execute( + create_task_query_str, + {"modelRunId": self.uid, "projectId": project_id}, + ) task_id = res[mutation_name] status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){ MEALabelRegistrationTaskStatus(where: $where) {status errorMessage} } """ - return self._wait_until_done(lambda: self.client.execute( - status_query_str, {'where': { - 'id': task_id - }})['MEALabelRegistrationTaskStatus'], - timeout_seconds=timeout_seconds) - - def upsert_data_rows(self, - data_row_ids=None, - global_keys=None, - timeout_seconds=3600): - """ Adds data rows to a Model Run without any associated labels + return self._wait_until_done( + lambda: self.client.execute( + status_query_str, {"where": {"id": task_id}} + )["MEALabelRegistrationTaskStatus"], + timeout_seconds=timeout_seconds, + ) + + def upsert_data_rows( + self, data_row_ids=None, global_keys=None, timeout_seconds=3600 + ): + """Adds data rows to a Model Run without any associated labels Args: data_row_ids (list): data row ids to add to model run global_keys (list): global keys for data rows to add to model run @@ -145,37 +166,40 @@ def upsert_data_rows(self, ID of newly generated async task """ - mutation_name = 'createMEAModelRunDataRowRegistrationTask' + mutation_name = "createMEAModelRunDataRowRegistrationTask" create_task_query_str = """mutation createMEAModelRunDataRowRegistrationTaskPyApi($modelRunId: ID!, $dataRowIds: [ID!], $globalKeys: [ID!]) { %s(where : { id : $modelRunId}, data : {dataRowIds: $dataRowIds, globalKeys: $globalKeys})} """ % (mutation_name) res = self.client.execute( - create_task_query_str, { - 'modelRunId': self.uid, - 'dataRowIds': data_row_ids, - 'globalKeys': global_keys - }) + create_task_query_str, + { + "modelRunId": self.uid, + "dataRowIds": data_row_ids, + "globalKeys": global_keys, + }, + ) task_id = res[mutation_name] status_query_str = """query MEADataRowRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){ MEADataRowRegistrationTaskStatus(where: $where) {status 
errorMessage} } """ - return self._wait_until_done(lambda: self.client.execute( - status_query_str, {'where': { - 'id': task_id - }})['MEADataRowRegistrationTaskStatus'], - timeout_seconds=timeout_seconds) + return self._wait_until_done( + lambda: self.client.execute( + status_query_str, {"where": {"id": task_id}} + )["MEADataRowRegistrationTaskStatus"], + timeout_seconds=timeout_seconds, + ) def _wait_until_done(self, status_fn, timeout_seconds=120, sleep_time=5): # Do not use this function outside of the scope of upsert_data_rows or upsert_labels. It could change. original_timeout = timeout_seconds while True: res = status_fn() - if res['status'] == 'COMPLETE': + if res["status"] == "COMPLETE": return True - elif res['status'] == 'FAILED': + elif res["status"] == "FAILED": raise Exception(f"Job failed.") timeout_seconds -= sleep_time if timeout_seconds <= 0: @@ -190,7 +214,7 @@ def upsert_predictions_and_send_to_project( predictions: Union[str, Path, Iterable[Dict]], project_id: str, priority: Optional[int] = 5, - ) -> 'MEAPredictionImport': # type: ignore + ) -> "MEAPredictionImport": # type: ignore """ Provides a convenient way to execute the following steps in a single function call: 1. Upload predictions to a Model @@ -230,11 +254,14 @@ def upsert_predictions_and_send_to_project( import_job = self.add_predictions(name, predictions) prediction_statuses = import_job.statuses mea_to_mal_data_rows = list( - set([ - row['dataRow']['id'] - for row in prediction_statuses - if row['status'] == 'SUCCESS' - ])) + set( + [ + row["dataRow"]["id"] + for row in prediction_statuses + if row["status"] == "SUCCESS" + ] + ) + ) if not mea_to_mal_data_rows: # 0 successful model predictions imported @@ -254,10 +281,13 @@ def upsert_predictions_and_send_to_project( return import_job, None, None try: - mal_prediction_import = Entity.MEAToMALPredictionImport.create_for_model_run_data_rows( - data_row_ids=mea_to_mal_data_rows, - project_id=project_id, - **kwargs) + mal_prediction_import = ( + Entity.MEAToMALPredictionImport.create_for_model_run_data_rows( + data_row_ids=mea_to_mal_data_rows, + project_id=project_id, + **kwargs, + ) + ) mal_prediction_import.wait_until_done() except Exception as e: logger.warning( @@ -272,7 +302,7 @@ def add_predictions( self, name: str, predictions: Union[str, Path, Iterable[Dict], Iterable["Label"]], - ) -> 'MEAPredictionImport': # type: ignore + ) -> "MEAPredictionImport": # type: ignore """ Uploads predictions to a new Editor project. 
@@ -289,17 +319,21 @@ def add_predictions( kwargs = dict(client=self.client, id=self.uid, name=name) if isinstance(predictions, str) or isinstance(predictions, Path): if os.path.exists(predictions): - return Entity.MEAPredictionImport.create(path=str(predictions), - **kwargs) + return Entity.MEAPredictionImport.create( + path=str(predictions), **kwargs + ) else: - return Entity.MEAPredictionImport.create(url=str(predictions), - **kwargs) + return Entity.MEAPredictionImport.create( + url=str(predictions), **kwargs + ) elif isinstance(predictions, Iterable): - return Entity.MEAPredictionImport.create(labels=predictions, - **kwargs) + return Entity.MEAPredictionImport.create( + labels=predictions, **kwargs + ) else: raise ValueError( - f'Invalid predictions given of type: {type(predictions)}') + f"Invalid predictions given of type: {type(predictions)}" + ) def model_run_data_rows(self): query_str = """query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){ @@ -308,13 +342,16 @@ def model_run_data_rows(self): } """ % (results_query_part(ModelRunDataRow)) return PaginatedCollection( - self.client, query_str, {'modelRunId': self.uid}, - ['annotationGroups', 'nodes'], + self.client, + query_str, + {"modelRunId": self.uid}, + ["annotationGroups", "nodes"], lambda client, res: ModelRunDataRow(client, self.model_id, res), - ['annotationGroups', 'pageInfo', 'endCursor']) + ["annotationGroups", "pageInfo", "endCursor"], + ) def delete(self): - """ Deletes specified Model Run. + """Deletes specified Model Run. Returns: Query execution success. @@ -325,7 +362,7 @@ def delete(self): self.client.execute(query_str, {ids_param: str(self.uid)}) def delete_model_run_data_rows(self, data_row_ids: List[str]): - """ Deletes data rows from Model Runs. + """Deletes data rows from Model Runs. Args: data_row_ids (list): List of data row ids to delete from the Model Run. @@ -336,136 +373,150 @@ def delete_model_run_data_rows(self, data_row_ids: List[str]): data_row_ids_param = "dataRowIds" query_str = """mutation DeleteModelRunDataRowsPyApi($%s: ID!, $%s: [ID!]!) { deleteModelRunDataRows(where: {modelRunId: $%s, dataRowIds: $%s})}""" % ( - model_run_id_param, data_row_ids_param, model_run_id_param, - data_row_ids_param) - self.client.execute(query_str, { - model_run_id_param: self.uid, - data_row_ids_param: data_row_ids - }) + model_run_id_param, + data_row_ids_param, + model_run_id_param, + data_row_ids_param, + ) + self.client.execute( + query_str, + {model_run_id_param: self.uid, data_row_ids_param: data_row_ids}, + ) @experimental - def assign_data_rows_to_split(self, - data_row_ids: List[str] = None, - split: Union[DataSplit, str] = None, - global_keys: List[str] = None, - timeout_seconds=120): - + def assign_data_rows_to_split( + self, + data_row_ids: List[str] = None, + split: Union[DataSplit, str] = None, + global_keys: List[str] = None, + timeout_seconds=120, + ): split_value = split.value if isinstance(split, DataSplit) else split valid_splits = DataSplit._member_names_ if split_value is None or split_value not in valid_splits: raise ValueError( - f"`split` must be one of : `{valid_splits}`. Found : `{split}`") + f"`split` must be one of : `{valid_splits}`. 
Found : `{split}`" + ) task_id = self.client.execute( """mutation assignDataSplitPyApi($modelRunId: ID!, $data: CreateAssignDataRowsToDataSplitTaskInput!){ createAssignDataRowsToDataSplitTask(modelRun : {id: $modelRunId}, data: $data)} - """, { - 'modelRunId': self.uid, - 'data': { - 'assignments': [{ - 'split': split_value, - 'dataRowIds': data_row_ids, - 'globalKeys': global_keys, - }] - } + """, + { + "modelRunId": self.uid, + "data": { + "assignments": [ + { + "split": split_value, + "dataRowIds": data_row_ids, + "globalKeys": global_keys, + } + ] + }, }, - experimental=True)['createAssignDataRowsToDataSplitTask'] + experimental=True, + )["createAssignDataRowsToDataSplitTask"] status_query_str = """query assignDataRowsToDataSplitTaskStatusPyApi($id: ID!){ assignDataRowsToDataSplitTaskStatus(where: {id : $id}){status errorMessage}} """ - return self._wait_until_done(lambda: self.client.execute( - status_query_str, {'id': task_id}, experimental=True)[ - 'assignDataRowsToDataSplitTaskStatus'], - timeout_seconds=timeout_seconds) + return self._wait_until_done( + lambda: self.client.execute( + status_query_str, {"id": task_id}, experimental=True + )["assignDataRowsToDataSplitTaskStatus"], + timeout_seconds=timeout_seconds, + ) @experimental - def update_status(self, - status: Union[str, "ModelRun.Status"], - metadata: Optional[Dict[str, str]] = None, - error_message: Optional[str] = None): - - status_value = status.value if isinstance(status, - ModelRun.Status) else status + def update_status( + self, + status: Union[str, "ModelRun.Status"], + metadata: Optional[Dict[str, str]] = None, + error_message: Optional[str] = None, + ): + status_value = ( + status.value if isinstance(status, ModelRun.Status) else status + ) if status_value not in ModelRun.Status._member_names_: raise ValueError( f"Status must be one of : `{ModelRun.Status._member_names_}`. 
Found : `{status_value}`" ) - data: Dict[str, Any] = {'status': status_value} + data: Dict[str, Any] = {"status": status_value} if error_message: - data['errorMessage'] = error_message + data["errorMessage"] = error_message if metadata: - data['metadata'] = metadata + data["metadata"] = metadata self.client.execute( """mutation setPipelineStatusPyApi($modelRunId: ID!, $data: UpdateTrainingPipelineInput!){ updateTrainingPipeline(modelRun: {id : $modelRunId}, data: $data){status} } - """, { - 'modelRunId': self.uid, - 'data': data - }, - experimental=True) + """, + {"modelRunId": self.uid, "data": data}, + experimental=True, + ) @experimental def update_config(self, config: Dict[str, Any]) -> Dict[str, Any]: """ - Updates the Model Run's training metadata config - Args: - config (dict): A dictionary of keys and values - Returns: - Model Run id and updated training metadata - """ - data: Dict[str, Any] = {'config': config} + Updates the Model Run's training metadata config + Args: + config (dict): A dictionary of keys and values + Returns: + Model Run id and updated training metadata + """ + data: Dict[str, Any] = {"config": config} res = self.client.execute( """mutation updateModelRunConfigPyApi($modelRunId: ID!, $data: UpdateModelRunConfigInput!){ updateModelRunConfig(modelRun: {id : $modelRunId}, data: $data){trainingMetadata} } - """, { - 'modelRunId': self.uid, - 'data': data - }, - experimental=True) + """, + {"modelRunId": self.uid, "data": data}, + experimental=True, + ) return res["updateModelRunConfig"] @experimental def reset_config(self) -> Dict[str, Any]: """ - Resets Model Run's training metadata config - Returns: - Model Run id and reset training metadata - """ + Resets Model Run's training metadata config + Returns: + Model Run id and reset training metadata + """ res = self.client.execute( """mutation resetModelRunConfigPyApi($modelRunId: ID!){ resetModelRunConfig(modelRun: {id : $modelRunId}){trainingMetadata} } - """, {'modelRunId': self.uid}, - experimental=True) + """, + {"modelRunId": self.uid}, + experimental=True, + ) return res["resetModelRunConfig"] @experimental def get_config(self) -> Dict[str, Any]: """ - Gets Model Run's training metadata - Returns: - training metadata as a dictionary - """ - res = self.client.execute("""query ModelRunPyApi($modelRunId: ID!){ + Gets Model Run's training metadata + Returns: + training metadata as a dictionary + """ + res = self.client.execute( + """query ModelRunPyApi($modelRunId: ID!){ modelRun(where: {id : $modelRunId}){trainingMetadata} } - """, {'modelRunId': self.uid}, - experimental=True) + """, + {"modelRunId": self.uid}, + experimental=True, + ) return res["modelRun"]["trainingMetadata"] @experimental def export_labels( - self, - download: bool = False, - timeout_seconds: int = 600 + self, download: bool = False, timeout_seconds: int = 600 ) -> Optional[Union[str, List[Dict[Any, Any]]]]: """ Experimental. To use, make sure client has enable_experimental=True. @@ -482,7 +533,8 @@ def export_labels( """ warnings.warn( "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) + DeprecationWarning, + ) sleep_time = 2 query_str = """mutation exportModelRunAnnotationsPyApi($modelRunId: ID!) 
{ exportModelRunAnnotations(data: {modelRunId: $modelRunId}) { @@ -493,8 +545,8 @@ def export_labels( while True: url = self.client.execute( - query_str, {'modelRunId': self.uid}, - experimental=True)['exportModelRunAnnotations']['downloadUrl'] + query_str, {"modelRunId": self.uid}, experimental=True + )["exportModelRunAnnotations"]["downloadUrl"] if url: if not download: @@ -508,13 +560,16 @@ def export_labels( if timeout_seconds <= 0: return None - logger.debug("ModelRun '%s' label export, waiting for server...", - self.uid) + logger.debug( + "ModelRun '%s' label export, waiting for server...", self.uid + ) time.sleep(sleep_time) - def export(self, - task_name: Optional[str] = None, - params: Optional[ModelRunExportParams] = None) -> ExportTask: + def export( + self, + task_name: Optional[str] = None, + params: Optional[ModelRunExportParams] = None, + ) -> ExportTask: """ Creates a model run export task with the given params and returns the task. @@ -536,7 +591,7 @@ def export_v2( """ task, is_streamable = self._export(task_name, params) - if (is_streamable): + if is_streamable: return ExportTask(task, True) return task @@ -550,48 +605,50 @@ def _export( create_task_query_str = ( f"mutation {mutation_name}PyApi" f"($input: ExportDataRowsInModelRunInput!)" - f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}") + f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}" + ) _params = params or ModelRunExportParams() query_params = { "input": { "taskName": task_name, - "filters": { - "modelRunId": self.uid - }, + "filters": {"modelRunId": self.uid}, "isStreamableReady": True, "params": { - "mediaTypeOverride": - _params.get('media_type_override', None), - "includeAttachments": - _params.get('attachments', False), - "includeEmbeddings": - _params.get('embeddings', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includePredictions": - _params.get('predictions', False), - "includeModelRunDetails": - _params.get('model_run_details', False), + "mediaTypeOverride": _params.get( + "media_type_override", None + ), + "includeAttachments": _params.get("attachments", False), + "includeEmbeddings": _params.get("embeddings", False), + "includeMetadata": _params.get("metadata_fields", False), + "includeDataRowDetails": _params.get( + "data_row_details", False + ), + "includePredictions": _params.get("predictions", False), + "includeModelRunDetails": _params.get( + "model_run_details", False + ), }, - "streamable": streamable + "streamable": streamable, } } - res = self.client.execute(create_task_query_str, - query_params, - error_log_key="errors") + res = self.client.execute( + create_task_query_str, query_params, error_log_key="errors" + ) res = res[mutation_name] task_id = res["taskId"] is_streamable = res["isStreamable"] return Task.get_task(self.client, task_id), is_streamable def send_to_annotate_from_model( - self, destination_project_id: str, task_queue_id: Optional[str], - batch_name: str, data_rows: Union[DataRowIds, GlobalKeys], - params: SendToAnnotateFromModelParams) -> Task: + self, + destination_project_id: str, + task_queue_id: Optional[str], + batch_name: str, + data_rows: Union[DataRowIds, GlobalKeys], + params: SendToAnnotateFromModelParams, + ) -> Task: """ Sends data rows from a model run to a project for annotation. 
@@ -625,46 +682,46 @@ def send_to_annotate_from_model( """ destination_task_queue = build_destination_task_queue_input( - task_queue_id) + task_queue_id + ) data_rows_query = self.client.build_catalog_query(data_rows) predictions_ontology_mapping = params.get( - "predictions_ontology_mapping", None) + "predictions_ontology_mapping", None + ) predictions_input = build_predictions_input( - predictions_ontology_mapping, self.uid) + predictions_ontology_mapping, self.uid + ) batch_priority = params.get("batch_priority", 5) exclude_data_rows_in_project = params.get( - "exclude_data_rows_in_project", False) + "exclude_data_rows_in_project", False + ) override_existing_annotations_rule = params.get( "override_existing_annotations_rule", - ConflictResolutionStrategy.KeepExisting) + ConflictResolutionStrategy.KeepExisting, + ) res = self.client.execute( - mutation_str, { + mutation_str, + { "input": { - "destinationProjectId": - destination_project_id, + "destinationProjectId": destination_project_id, "batchInput": { "batchName": batch_name, - "batchPriority": batch_priority + "batchPriority": batch_priority, }, - "destinationTaskQueue": - destination_task_queue, - "excludeDataRowsInProject": - exclude_data_rows_in_project, - "annotationsInput": - None, - "predictionsInput": - predictions_input, - "conflictLabelsResolutionStrategy": - override_existing_annotations_rule, + "destinationTaskQueue": destination_task_queue, + "excludeDataRowsInProject": exclude_data_rows_in_project, + "annotationsInput": None, + "predictionsInput": predictions_input, + "conflictLabelsResolutionStrategy": override_existing_annotations_rule, "searchQuery": [data_rows_query], - "sourceModelRunId": - self.uid + "sourceModelRunId": self.uid, } - })['sendToAnnotateFromMea'] + }, + )["sendToAnnotateFromMea"] - return Entity.Task.get_task(self.client, res['taskId']) + return Entity.Task.get_task(self.client, res["taskId"]) class ModelRunDataRow(DbObject): diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index 7b74acdc2..efe32611b 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -13,9 +13,12 @@ import json from pydantic import StringConstraints -FeatureSchemaId: Type[str] = Annotated[str, StringConstraints(min_length=25, - max_length=25)] -SchemaId: Type[str] = Annotated[str, StringConstraints(min_length=25, max_length=25)] +FeatureSchemaId: Type[str] = Annotated[ + str, StringConstraints(min_length=25, max_length=25) +] +SchemaId: Type[str] = Annotated[ + str, StringConstraints(min_length=25, max_length=25) +] class DeleteFeatureFromOntologyResult: @@ -23,8 +26,10 @@ class DeleteFeatureFromOntologyResult: deleted: bool def __str__(self): - return "<%s %s>" % (self.__class__.__name__.split(".")[-1], - json.dumps(self.__dict__)) + return "<%s %s>" % ( + self.__class__.__name__.split(".")[-1], + json.dumps(self.__dict__), + ) class FeatureSchema(DbObject): @@ -50,11 +55,14 @@ class Option: feature_schema_id: (str) options: (list) """ + value: Union[str, int] label: Optional[Union[str, int]] = None schema_id: Optional[str] = None feature_schema_id: Optional[FeatureSchemaId] = None - options: Union[List["Classification"], List["PromptResponseClassification"]] = field(default_factory=list) + options: Union[ + List["Classification"], List["PromptResponseClassification"] + ] = field(default_factory=list) def __post_init__(self): if self.label is None: @@ -62,17 +70,18 @@ def __post_init__(self): @classmethod def 
from_dict( - cls, - dictionary: Dict[str, - Any]) -> Dict[Union[str, int], Union[str, int]]: - return cls(value=dictionary["value"], - label=dictionary["label"], - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None), - options=[ - Classification.from_dict(o) - for o in dictionary.get("options", []) - ]) + cls, dictionary: Dict[str, Any] + ) -> Dict[Union[str, int], Union[str, int]]: + return cls( + value=dictionary["value"], + label=dictionary["label"], + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None), + options=[ + Classification.from_dict(o) + for o in dictionary.get("options", []) + ], + ) def asdict(self) -> Dict[str, Any]: return { @@ -80,20 +89,23 @@ def asdict(self) -> Dict[str, Any]: "featureSchemaId": self.feature_schema_id, "label": self.label, "value": self.value, - "options": [o.asdict(is_subclass=True) for o in self.options] + "options": [o.asdict(is_subclass=True) for o in self.options], } - def add_option(self, option: Union["Classification", "PromptResponseClassification"]) -> None: + def add_option( + self, option: Union["Classification", "PromptResponseClassification"] + ) -> None: if option.name in (o.name for o in self.options): raise InconsistentOntologyException( f"Duplicate nested classification '{option.name}' " - f"for option '{self.label}'") + f"for option '{self.label}'" + ) self.options.append(option) @dataclass class Classification: - """ + """ A classification to be added to a Project's ontology. The classification is dependent on the Classification Type. @@ -135,7 +147,7 @@ class Type(Enum): class Scope(Enum): GLOBAL = "global" INDEX = "index" - + class UIMode(Enum): HOTKEY = "hotkey" SEARCHABLE = "searchable" @@ -150,7 +162,9 @@ class UIMode(Enum): schema_id: Optional[str] = None feature_schema_id: Optional[str] = None scope: Scope = None - ui_mode: Optional[UIMode] = None # How this classification should be answered (e.g. hotkeys / autocomplete, etc) + ui_mode: Optional[UIMode] = ( + None # How this classification should be answered (e.g. hotkeys / autocomplete, etc) + ) def __post_init__(self): if self.name is None: @@ -159,7 +173,8 @@ def __post_init__(self): "for the classification schema name, which will be used when " "creating annotation payload for Model-Assisted Labeling " "Import and Label Import. “instructions” is no longer " - "supported to specify classification schema name.") + "supported to specify classification schema name." 
+ ) if self.instructions is not None: self.name = self.instructions warnings.warn(msg) @@ -171,21 +186,25 @@ def __post_init__(self): @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: - return cls(class_type=cls.Type(dictionary["type"]), - name=dictionary["name"], - instructions=dictionary["instructions"], - required=dictionary.get("required", False), - options=[Option.from_dict(o) for o in dictionary["options"]], - ui_mode=cls.UIMode(dictionary["uiMode"]) if "uiMode" in dictionary else None, - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None), - scope=cls.Scope(dictionary.get("scope", cls.Scope.GLOBAL))) + return cls( + class_type=cls.Type(dictionary["type"]), + name=dictionary["name"], + instructions=dictionary["instructions"], + required=dictionary.get("required", False), + options=[Option.from_dict(o) for o in dictionary["options"]], + ui_mode=cls.UIMode(dictionary["uiMode"]) + if "uiMode" in dictionary + else None, + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None), + scope=cls.Scope(dictionary.get("scope", cls.Scope.GLOBAL)), + ) def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: - if self.class_type in self._REQUIRES_OPTIONS \ - and len(self.options) < 1: + if self.class_type in self._REQUIRES_OPTIONS and len(self.options) < 1: raise InconsistentOntologyException( - f"Classification '{self.name}' requires options.") + f"Classification '{self.name}' requires options." + ) classification = { "type": self.class_type.value, "instructions": self.instructions, @@ -193,24 +212,32 @@ def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: "required": self.required, "options": [o.asdict() for o in self.options], "schemaNodeId": self.schema_id, - "featureSchemaId": self.feature_schema_id + "featureSchemaId": self.feature_schema_id, } - if (self.class_type == self.Type.RADIO or self.class_type == self.Type.CHECKLIST) and self.ui_mode: + if ( + self.class_type == self.Type.RADIO + or self.class_type == self.Type.CHECKLIST + ) and self.ui_mode: # added because this key does nothing for text so no point of including classification["uiMode"] = self.ui_mode.value if is_subclass: return classification - classification[ - "scope"] = self.scope.value if self.scope is not None else self.Scope.GLOBAL.value + classification["scope"] = ( + self.scope.value + if self.scope is not None + else self.Scope.GLOBAL.value + ) return classification def add_option(self, option: Option) -> None: if option.value in (o.value for o in self.options): raise InconsistentOntologyException( f"Duplicate option '{option.value}' " - f"for classification '{self.name}'.") + f"for classification '{self.name}'." 
+ ) self.options.append(option) - + + @dataclass class ResponseOption(Option): """ @@ -228,26 +255,27 @@ class ResponseOption(Option): feature_schema_id: (str) options: (list) """ - + @classmethod def from_dict( - cls, - dictionary: Dict[str, - Any]) -> Dict[Union[str, int], Union[str, int]]: - return cls(value=dictionary["value"], - label=dictionary["label"], - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None), - options=[ - PromptResponseClassification.from_dict(o) - for o in dictionary.get("options", []) - ]) + cls, dictionary: Dict[str, Any] + ) -> Dict[Union[str, int], Union[str, int]]: + return cls( + value=dictionary["value"], + label=dictionary["label"], + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None), + options=[ + PromptResponseClassification.from_dict(o) + for o in dictionary.get("options", []) + ], + ) @dataclass class PromptResponseClassification: """ - + A PromptResponseClassification to be added to a Project's ontology. The classification is dependent on the PromptResponseClassification Type. @@ -268,7 +296,7 @@ class PromptResponseClassification: >>> classification_two = PromptResponseClassification( >>> class_type = PromptResponseClassification.Type.RESPONSE_RADIO, >>> name = "Second Example") - + >>> classification_two.add_option(ResponseOption( >>> value = "Option Example")) @@ -283,7 +311,7 @@ class PromptResponseClassification: schema_id: (str) feature_schema_id: (str) """ - + def __post_init__(self): if self.name is None: msg = ( @@ -291,7 +319,8 @@ def __post_init__(self): "for the classification schema name, which will be used when " "creating annotation payload for Model-Assisted Labeling " "Import and Label Import. “instructions” is no longer " - "supported to specify classification schema name.") + "supported to specify classification schema name." 
+ ) if self.instructions is not None: self.name = self.instructions warnings.warn(msg) @@ -303,7 +332,7 @@ def __post_init__(self): class Type(Enum): PROMPT = "prompt" - RESPONSE_TEXT= "response-text" + RESPONSE_TEXT = "response-text" RESPONSE_CHECKLIST = "response-checklist" RESPONSE_RADIO = "response-radio" @@ -321,31 +350,38 @@ class Type(Enum): @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: - return cls(class_type=cls.Type(dictionary["type"]), - name=dictionary["name"], - instructions=dictionary["instructions"], - required=True, # always required - options=[ResponseOption.from_dict(o) for o in dictionary["options"]], - character_min=dictionary.get("minCharacters", None), - character_max=dictionary.get("maxCharacters", None), - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None)) + return cls( + class_type=cls.Type(dictionary["type"]), + name=dictionary["name"], + instructions=dictionary["instructions"], + required=True, # always required + options=[ + ResponseOption.from_dict(o) for o in dictionary["options"] + ], + character_min=dictionary.get("minCharacters", None), + character_max=dictionary.get("maxCharacters", None), + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None), + ) def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: - if self.class_type in self._REQUIRES_OPTIONS \ - and len(self.options) < 1: + if self.class_type in self._REQUIRES_OPTIONS and len(self.options) < 1: raise InconsistentOntologyException( - f"Response Classification '{self.name}' requires options.") + f"Response Classification '{self.name}' requires options." + ) classification = { "type": self.class_type.value, "instructions": self.instructions, "name": self.name, - "required": True, # always required + "required": True, # always required "options": [o.asdict() for o in self.options], "schemaNodeId": self.schema_id, - "featureSchemaId": self.feature_schema_id + "featureSchemaId": self.feature_schema_id, } - if (self.class_type == self.Type.PROMPT or self.class_type == self.Type.RESPONSE_TEXT): + if ( + self.class_type == self.Type.PROMPT + or self.class_type == self.Type.RESPONSE_TEXT + ): if self.character_min: classification["minCharacters"] = self.character_min if self.character_max: @@ -358,7 +394,8 @@ def add_option(self, option: ResponseOption) -> None: if option.value in (o.value for o in self.options): raise InconsistentOntologyException( f"Duplicate option '{option.value}' " - f"for response classification '{self.name}'.") + f"for response classification '{self.name}'." 
+ ) self.options.append(option) @@ -402,9 +439,9 @@ class Type(Enum): LINE = "line" NER = "named-entity" RELATIONSHIP = "edge" - MESSAGE_SINGLE_SELECTION = 'message-single-selection' - MESSAGE_MULTI_SELECTION = 'message-multi-selection' - MESSAGE_RANKING = 'message-ranking' + MESSAGE_SINGLE_SELECTION = "message-single-selection" + MESSAGE_MULTI_SELECTION = "message-multi-selection" + MESSAGE_RANKING = "message-ranking" tool: Type name: str @@ -416,16 +453,18 @@ class Type(Enum): @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: - return cls(name=dictionary['name'], - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None), - required=dictionary.get("required", False), - tool=cls.Type(dictionary["tool"]), - classifications=[ - Classification.from_dict(c) - for c in dictionary["classifications"] - ], - color=dictionary["color"]) + return cls( + name=dictionary["name"], + schema_id=dictionary.get("schemaNodeId", None), + feature_schema_id=dictionary.get("featureSchemaId", None), + required=dictionary.get("required", False), + tool=cls.Type(dictionary["tool"]), + classifications=[ + Classification.from_dict(c) + for c in dictionary["classifications"] + ], + color=dictionary["color"], + ) def asdict(self) -> Dict[str, Any]: return { @@ -437,14 +476,15 @@ def asdict(self) -> Dict[str, Any]: c.asdict(is_subclass=True) for c in self.classifications ], "schemaNodeId": self.schema_id, - "featureSchemaId": self.feature_schema_id + "featureSchemaId": self.feature_schema_id, } def add_classification(self, classification: Classification) -> None: if classification.name in (c.name for c in self.classifications): raise InconsistentOntologyException( f"Duplicate nested classification '{classification.name}' " - f"for tool '{self.name}'") + f"for tool '{self.name}'" + ) self.classifications.append(classification) @@ -477,25 +517,37 @@ class Ontology(DbObject): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._tools: Optional[List[Tool]] = None - self._classifications: Optional[Union[List[Classification],List[PromptResponseClassification]]] = None + self._classifications: Optional[ + Union[List[Classification], List[PromptResponseClassification]] + ] = None def tools(self) -> List[Tool]: """Get list of tools (AKA objects) in an Ontology.""" if self._tools is None: self._tools = [ - Tool.from_dict(tool) for tool in self.normalized['tools'] + Tool.from_dict(tool) for tool in self.normalized["tools"] ] return self._tools - def classifications(self) -> List[Union[Classification, PromptResponseClassification]]: + def classifications( + self, + ) -> List[Union[Classification, PromptResponseClassification]]: """Get list of classifications in an Ontology.""" if self._classifications is None: self._classifications = [] for classification in self.normalized["classifications"]: - if "type" in classification and classification["type"] in PromptResponseClassification.Type._value2member_map_.keys(): - self._classifications.append(PromptResponseClassification.from_dict(classification)) + if ( + "type" in classification + and classification["type"] + in PromptResponseClassification.Type._value2member_map_.keys() + ): + self._classifications.append( + PromptResponseClassification.from_dict(classification) + ) else: - self._classifications.append(Classification.from_dict(classification)) + self._classifications.append( + Classification.from_dict(classification) + ) return self._classifications @@ -524,36 +576,52 
@@ class OntologyBuilder: """ + tools: List[Tool] = field(default_factory=list) - classifications: List[Union[Classification, PromptResponseClassification]] = field(default_factory=list) + classifications: List[ + Union[Classification, PromptResponseClassification] + ] = field(default_factory=list) @classmethod def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: classifications = [] for c in dictionary["classifications"]: - if "type" in c and c["type"] in PromptResponseClassification.Type._value2member_map_.keys(): - classifications.append(PromptResponseClassification.from_dict(c)) + if ( + "type" in c + and c["type"] + in PromptResponseClassification.Type._value2member_map_.keys() + ): + classifications.append( + PromptResponseClassification.from_dict(c) + ) else: classifications.append(Classification.from_dict(c)) - return cls(tools=[Tool.from_dict(t) for t in dictionary["tools"]], - classifications=classifications) + return cls( + tools=[Tool.from_dict(t) for t in dictionary["tools"]], + classifications=classifications, + ) def asdict(self) -> Dict[str, Any]: self._update_colors() classifications = [] prompts = 0 for c in self.classifications: - if hasattr(c, "class_type") and c.class_type in PromptResponseClassification.Type: + if ( + hasattr(c, "class_type") + and c.class_type in PromptResponseClassification.Type + ): if c.class_type == PromptResponseClassification.Type.PROMPT: prompts += 1 if prompts > 1: - raise ValueError("Only one prompt is allowed per ontology") + raise ValueError( + "Only one prompt is allowed per ontology" + ) classifications.append(PromptResponseClassification.asdict(c)) else: classifications.append(Classification.asdict(c)) return { "tools": [t.asdict() for t in self.tools], - "classifications": classifications + "classifications": classifications, } def _update_colors(self): @@ -562,9 +630,10 @@ def _update_colors(self): for index in range(num_tools): hsv_color = (index * 1 / num_tools, 1, 1) rgb_color = tuple( - int(255 * x) for x in colorsys.hsv_to_rgb(*hsv_color)) + int(255 * x) for x in colorsys.hsv_to_rgb(*hsv_color) + ) if self.tools[index].color is None: - self.tools[index].color = '#%02x%02x%02x' % rgb_color + self.tools[index].color = "#%02x%02x%02x" % rgb_color @classmethod def from_project(cls, project: "project.Project") -> "OntologyBuilder": @@ -578,11 +647,16 @@ def from_ontology(cls, ontology: Ontology) -> "OntologyBuilder": def add_tool(self, tool: Tool) -> None: if tool.name in (t.name for t in self.tools): raise InconsistentOntologyException( - f"Duplicate tool name '{tool.name}'. ") + f"Duplicate tool name '{tool.name}'. " + ) self.tools.append(tool) - def add_classification(self, classification: Union[Classification, PromptResponseClassification]) -> None: + def add_classification( + self, + classification: Union[Classification, PromptResponseClassification], + ) -> None: if classification.name in (c.name for c in self.classifications): raise InconsistentOntologyException( - f"Duplicate classification name '{classification.name}'. ") + f"Duplicate classification name '{classification.name}'. 
" + ) self.classifications.append(classification) diff --git a/libs/labelbox/src/labelbox/schema/ontology_kind.py b/libs/labelbox/src/labelbox/schema/ontology_kind.py index 7dd3311cb..3171b811e 100644 --- a/libs/labelbox/src/labelbox/schema/ontology_kind.py +++ b/libs/labelbox/src/labelbox/schema/ontology_kind.py @@ -8,6 +8,7 @@ class OntologyKind(Enum): """ OntologyKind is an enum that represents the different types of ontologies """ + ModelEvaluation = "MODEL_EVALUATION" ResponseCreation = "RESPONSE_CREATION" Missing = None @@ -18,27 +19,31 @@ def is_supported(cls, value): @classmethod def get_ontology_kind_validation_error(cls, ontology_kind): - return TypeError(f"{ontology_kind}: is not a valid ontology kind. Use" - f" any of {OntologyKind.__members__.items()}" - " from OntologyKind.") + return TypeError( + f"{ontology_kind}: is not a valid ontology kind. Use" + f" any of {OntologyKind.__members__.items()}" + " from OntologyKind." + ) @staticmethod def evaluate_ontology_kind_with_media_type( - ontology_kind, - media_type: Optional[MediaType]) -> Union[MediaType, None]: - + ontology_kind, media_type: Optional[MediaType] + ) -> Union[MediaType, None]: ontology_to_media = { - OntologyKind.ModelEvaluation: - (MediaType.Conversational, - "For chat evaluation, media_type must be Conversational."), - OntologyKind.ResponseCreation: - (MediaType.Text, - "For response creation, media_type must be Text.") + OntologyKind.ModelEvaluation: ( + MediaType.Conversational, + "For chat evaluation, media_type must be Conversational.", + ), + OntologyKind.ResponseCreation: ( + MediaType.Text, + "For response creation, media_type must be Text.", + ), } if ontology_kind in ontology_to_media: expected_media_type, error_message = ontology_to_media[ - ontology_kind] + ontology_kind + ] if media_type is None or media_type == expected_media_type: media_type = expected_media_type @@ -59,10 +64,10 @@ def is_supported(cls, value): return isinstance(value, cls) @classmethod - def _missing_(cls, value) -> 'EditorTaskType': + def _missing_(cls, value) -> "EditorTaskType": """Handle missing null new task types - Handle upper case names for compatibility with - the GraphQL""" + Handle upper case names for compatibility with + the GraphQL""" if value is None: return cls.Missing @@ -75,34 +80,45 @@ def _missing_(cls, value) -> 'EditorTaskType': class EditorTaskTypeMapper: - @staticmethod - def to_editor_task_type(ontology_kind: OntologyKind, - media_type: MediaType) -> EditorTaskType: - if ontology_kind and OntologyKind.is_supported( - ontology_kind) and media_type and MediaType.is_supported( - media_type): + def to_editor_task_type( + ontology_kind: OntologyKind, media_type: MediaType + ) -> EditorTaskType: + if ( + ontology_kind + and OntologyKind.is_supported(ontology_kind) + and media_type + and MediaType.is_supported(media_type) + ): editor_task_type = EditorTaskTypeMapper.map_to_editor_task_type( - ontology_kind, media_type) + ontology_kind, media_type + ) else: editor_task_type = EditorTaskType.Missing return editor_task_type @staticmethod - def map_to_editor_task_type(onotology_kind: OntologyKind, - media_type: MediaType) -> EditorTaskType: - if onotology_kind == OntologyKind.ModelEvaluation and media_type == MediaType.Conversational: + def map_to_editor_task_type( + onotology_kind: OntologyKind, media_type: MediaType + ) -> EditorTaskType: + if ( + onotology_kind == OntologyKind.ModelEvaluation + and media_type == MediaType.Conversational + ): return EditorTaskType.ModelChatEvaluation - elif onotology_kind == 
OntologyKind.ResponseCreation and media_type == MediaType.Text: + elif ( + onotology_kind == OntologyKind.ResponseCreation + and media_type == MediaType.Text + ): return EditorTaskType.ResponseCreation else: return EditorTaskType.Missing class UploadType(Enum): - Auto = 'AUTO', - Manual = 'MANUAL', + Auto = ("AUTO",) + Manual = ("MANUAL",) Missing = None @classmethod @@ -110,7 +126,7 @@ def is_supported(cls, value): return isinstance(value, cls) @classmethod - def _missing_(cls, value: object) -> 'UploadType': + def _missing_(cls, value: object) -> "UploadType": if value is None: return cls.Missing diff --git a/libs/labelbox/src/labelbox/schema/organization.py b/libs/labelbox/src/labelbox/schema/organization.py index 3a5e23efc..71e715f11 100644 --- a/libs/labelbox/src/labelbox/schema/organization.py +++ b/libs/labelbox/src/labelbox/schema/organization.py @@ -9,11 +9,18 @@ from labelbox.schema.resource_tag import ResourceTag if TYPE_CHECKING: - from labelbox import Role, User, ProjectRole, Invite, InviteLimit, IAMIntegration + from labelbox import ( + Role, + User, + ProjectRole, + Invite, + InviteLimit, + IAMIntegration, + ) class Organization(DbObject): - """ An Organization is a group of Users. + """An Organization is a group of Users. It is associated with data created by Users within that Organization. Typically all Users within an Organization have access to data created by any User in the same Organization. @@ -47,10 +54,11 @@ def __init__(self, *args, **kwargs): resource_tags = Relationship.ToMany("ResourceTags", False) def invite_user( - self, - email: str, - role: "Role", - project_roles: Optional[List["ProjectRole"]] = None) -> "Invite": + self, + email: str, + role: "Role", + project_roles: Optional[List["ProjectRole"]] = None, + ) -> "Invite": """ Invite a new member to the org. 
This will send the user an email invite @@ -76,30 +84,40 @@ def invite_user( data_param = "data" query_str = """mutation createInvitesPyApi($%s: [CreateInviteInput!]){ createInvites(data: $%s){ invite { id createdAt organizationRoleName inviteeEmail inviter { %s } }}}""" % ( - data_param, data_param, query.results_query_part(Entity.User)) - - projects = [{ - "projectId": project_role.project.uid, - "projectRoleId": project_role.role.uid - } for project_role in project_roles or []] + data_param, + data_param, + query.results_query_part(Entity.User), + ) + + projects = [ + { + "projectId": project_role.project.uid, + "projectRoleId": project_role.role.uid, + } + for project_role in project_roles or [] + ] res = self.client.execute( - query_str, { - data_param: [{ - "inviterId": self.client.get_user().uid, - "inviteeEmail": email, - "organizationId": self.uid, - "organizationRoleId": role.uid, - "projects": projects - }] - }) - invite_response = res['createInvites'][0]['invite'] + query_str, + { + data_param: [ + { + "inviterId": self.client.get_user().uid, + "inviteeEmail": email, + "organizationId": self.uid, + "organizationRoleId": role.uid, + "projects": projects, + } + ] + }, + ) + invite_response = res["createInvites"][0]["invite"] if not invite_response: raise LabelboxError(f"Unable to send invite for email {email}") return Entity.Invite(self.client, invite_response) def invite_limit(self) -> InviteLimit: - """ Retrieve invite limits for the org + """Retrieve invite limits for the org This already accounts for users currently in the org Meaining that `used = users + invites, remaining = limit - (users + invites)` @@ -111,10 +129,13 @@ def invite_limit(self) -> InviteLimit: res = self.client.execute( """query InvitesLimitPyApi($%s: ID!) { invitesLimit(where: {id: $%s}) { used limit remaining } - }""" % (org_id_param, org_id_param), {org_id_param: self.uid}) - return InviteLimit(**{ - utils.snake_case(k): v for k, v in res['invitesLimit'].items() - }) + }""" + % (org_id_param, org_id_param), + {org_id_param: self.uid}, + ) + return InviteLimit( + **{utils.snake_case(k): v for k, v in res["invitesLimit"].items()} + ) def remove_user(self, user: "User") -> None: """ @@ -128,7 +149,10 @@ def remove_user(self, user: "User") -> None: self.client.execute( """mutation DeleteMemberPyApi($%s: ID!) { updateUser(where: {id: $%s}, data: {deleted: true}) { id deleted } - }""" % (user_id_param, user_id_param), {user_id_param: user.uid}) + }""" + % (user_id_param, user_id_param), + {user_id_param: user.uid}, + ) def create_resource_tag(self, tag: Dict[str, str]) -> ResourceTag: """ @@ -145,30 +169,38 @@ def create_resource_tag(self, tag: Dict[str, str]) -> ResourceTag: query_str = """mutation CreateResourceTagPyApi($text:String!,$color:String!) { createResourceTag(input:{text:$%s,color:$%s}) {%s}} - """ % (tag_text_param, tag_color_param, - query.results_query_part(ResourceTag)) + """ % ( + tag_text_param, + tag_color_param, + query.results_query_part(ResourceTag), + ) params = { tag_text_param: tag.get("text", None), - tag_color_param: tag.get("color", None) + tag_color_param: tag.get("color", None), } if not all(params.values()): raise ValueError( - f"tag must contain 'text' and 'color' keys. received: {tag}") + f"tag must contain 'text' and 'color' keys. 
received: {tag}" + ) res = self.client.execute(query_str, params) - return ResourceTag(self.client, res['createResourceTag']) + return ResourceTag(self.client, res["createResourceTag"]) def get_resource_tags(self) -> List[ResourceTag]: """ Returns all resource tags for an organization """ - query_str = """query GetOrganizationResourceTagsPyApi{organization{resourceTag{%s}}}""" % ( - query.results_query_part(ResourceTag)) + query_str = ( + """query GetOrganizationResourceTagsPyApi{organization{resourceTag{%s}}}""" + % (query.results_query_part(ResourceTag)) + ) return [ - ResourceTag(self.client, tag) for tag in self.client.execute( - query_str)['organization']['resourceTag'] + ResourceTag(self.client, tag) + for tag in self.client.execute(query_str)["organization"][ + "resourceTag" + ] ] def get_iam_integrations(self) -> List["IAMIntegration"]: @@ -184,10 +216,12 @@ def get_iam_integrations(self) -> List["IAMIntegration"]: ... on GcpIamIntegrationSettings {serviceAccountEmailId readBucket} } - } } """ % query.results_query_part(Entity.IAMIntegration)) + } } """ + % query.results_query_part(Entity.IAMIntegration) + ) return [ Entity.IAMIntegration(self.client, integration_data) - for integration_data in res['iamIntegrations'] + for integration_data in res["iamIntegrations"] ] def get_default_iam_integration(self) -> Optional["IAMIntegration"]: @@ -197,12 +231,14 @@ def get_default_iam_integration(self) -> Optional["IAMIntegration"]: """ integrations = self.get_iam_integrations() default_integration = [ - integration for integration in integrations + integration + for integration in integrations if integration.is_org_default ] if len(default_integration) > 1: raise ValueError( "Found more than one default signer. Please contact Labelbox to resolve" ) - return None if not len( - default_integration) else default_integration.pop() + return ( + None if not len(default_integration) else default_integration.pop() + ) diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py index a30ff856b..a45ddfa4b 100644 --- a/libs/labelbox/src/labelbox/schema/project.py +++ b/libs/labelbox/src/labelbox/schema/project.py @@ -6,19 +6,37 @@ from collections import namedtuple from datetime import datetime, timezone from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, TypeVar, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + TypeVar, + Union, + overload, +) from urllib.parse import urlparse -from labelbox.schema.labeling_service import LabelingService, LabelingServiceStatus +from labelbox.schema.labeling_service import ( + LabelingService, + LabelingServiceStatus, +) from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard import requests from labelbox import parser from labelbox import utils from labelbox.exceptions import error_message_for_unparsed_graphql_error -from labelbox.exceptions import (InvalidQueryError, LabelboxError, - ProcessingWaitTimeout, ResourceConflict, - ResourceNotFoundError) +from labelbox.exceptions import ( + InvalidQueryError, + LabelboxError, + ProcessingWaitTimeout, + ResourceConflict, + ResourceNotFoundError, +) from labelbox.orm import query from labelbox.orm.db_object import DbObject, Deletable, Updateable, experimental from labelbox.orm.model import Entity, Field, Relationship @@ -26,7 +44,11 @@ from labelbox.schema.consensus_settings import ConsensusSettings from labelbox.schema.create_batches_task 
import CreateBatchesTask from labelbox.schema.data_row import DataRow -from labelbox.schema.export_filters import ProjectExportFilters, validate_datetime, build_filters +from labelbox.schema.export_filters import ( + ProjectExportFilters, + validate_datetime, + build_filters, +) from labelbox.schema.export_params import ProjectExportParams from labelbox.schema.export_task import ExportTask from labelbox.schema.id_type import IdType @@ -39,24 +61,32 @@ from labelbox.schema.resource_tag import ResourceTag from labelbox.schema.task import Task from labelbox.schema.task_queue import TaskQueue -from labelbox.schema.ontology_kind import (EditorTaskType, OntologyKind, - UploadType) -from labelbox.schema.project_overview import ProjectOverview, ProjectOverviewDetailed +from labelbox.schema.ontology_kind import ( + EditorTaskType, + OntologyKind, + UploadType, +) +from labelbox.schema.project_overview import ( + ProjectOverview, + ProjectOverviewDetailed, +) if TYPE_CHECKING: from labelbox import BulkImportRequest DataRowPriority = int -LabelingParameterOverrideInput = Tuple[Union[DataRow, DataRowIdentifier], - DataRowPriority] +LabelingParameterOverrideInput = Tuple[ + Union[DataRow, DataRowIdentifier], DataRowPriority +] logger = logging.getLogger(__name__) MAX_SYNC_BATCH_ROW_COUNT = 1_000 def validate_labeling_parameter_overrides( - data: List[LabelingParameterOverrideInput]) -> None: + data: List[LabelingParameterOverrideInput], +) -> None: for idx, row in enumerate(data): if len(row) < 2: raise TypeError( @@ -131,11 +161,14 @@ class Project(DbObject, Updateable, Deletable): organization = Relationship.ToOne("Organization", False) labeling_frontend = Relationship.ToOne( "LabelingFrontend", - config=Relationship.Config(disconnect_supported=False)) + config=Relationship.Config(disconnect_supported=False), + ) labeling_frontend_options = Relationship.ToMany( - "LabelingFrontendOptions", False, "labeling_frontend_options") + "LabelingFrontendOptions", False, "labeling_frontend_options" + ) labeling_parameter_overrides = Relationship.ToMany( - "LabelingParameterOverride", False, "labeling_parameter_overrides") + "LabelingParameterOverride", False, "labeling_parameter_overrides" + ) webhooks = Relationship.ToMany("Webhook", False) benchmarks = Relationship.ToMany("Benchmark", False) ontology = Relationship.ToOne("Ontology", True) @@ -148,23 +181,31 @@ def is_chat_evaluation(self) -> bool: Returns: True if this project is a live chat evaluation project, False otherwise """ - return self.media_type == MediaType.Conversational and self.editor_task_type == EditorTaskType.ModelChatEvaluation + return ( + self.media_type == MediaType.Conversational + and self.editor_task_type == EditorTaskType.ModelChatEvaluation + ) def is_prompt_response(self) -> bool: """ Returns: True if this project is a prompt response project, False otherwise """ - return self.media_type == MediaType.LLMPromptResponseCreation or self.media_type == MediaType.LLMPromptCreation or self.editor_task_type == EditorTaskType.ResponseCreation + return ( + self.media_type == MediaType.LLMPromptResponseCreation + or self.media_type == MediaType.LLMPromptCreation + or self.editor_task_type == EditorTaskType.ResponseCreation + ) def is_auto_data_generation(self) -> bool: - return (self.upload_type == UploadType.Auto) # type: ignore + return self.upload_type == UploadType.Auto # type: ignore # we test not only the project ontology is None, but also a default empty ontology that we create when we attach a labeling front end in 
createLabelingFrontendOptions def is_empty_ontology(self) -> bool: ontology = self.ontology() # type: ignore - return ontology is None or (len(ontology.tools()) == 0 and - len(ontology.classifications()) == 0) + return ontology is None or ( + len(ontology.tools()) == 0 and len(ontology.classifications()) == 0 + ) def project_model_configs(self): query_str = """query ProjectModelConfigsPyApi($id: ID!) { @@ -189,7 +230,7 @@ def project_model_configs(self): ] def update(self, **kwargs): - """ Updates this project with the specified attributes + """Updates this project with the specified attributes Args: kwargs: a dictionary containing attributes to be upserted @@ -214,14 +255,16 @@ def update(self, **kwargs): if MediaType.is_supported(media_type): kwargs["media_type"] = media_type.value else: - raise TypeError(f"{media_type} is not a valid media type. Use" - f" any of {MediaType.get_supported_members()}" - " from MediaType. Example: MediaType.Image.") + raise TypeError( + f"{media_type} is not a valid media type. Use" + f" any of {MediaType.get_supported_members()}" + " from MediaType. Example: MediaType.Image." + ) return super().update(**kwargs) def members(self) -> PaginatedCollection: - """ Fetch all current members for this project + """Fetch all current members for this project Returns: A `PaginatedCollection` of `ProjectMember`s @@ -232,13 +275,18 @@ def members(self) -> PaginatedCollection: project(where: {id : $%s}) { id members(skip: %%d first: %%d){ id user { %s } role { id name } accessFrom } } }""" % (id_param, id_param, query.results_query_part(Entity.User)) - return PaginatedCollection(self.client, query_str, - {id_param: str(self.uid)}, - ["project", "members"], ProjectMember) + return PaginatedCollection( + self.client, + query_str, + {id_param: str(self.uid)}, + ["project", "members"], + ProjectMember, + ) def update_project_resource_tags( - self, resource_tag_ids: List[str]) -> List[ResourceTag]: - """ Creates project resource tags + self, resource_tag_ids: List[str] + ) -> List[ResourceTag]: + """Creates project resource tags Args: resource_tag_ids @@ -250,13 +298,18 @@ def update_project_resource_tags( query_str = """mutation UpdateProjectResourceTagsPyApi($%s:ID!,$%s:[String!]) { project(where:{id:$%s}){updateProjectResourceTags(input:{%s:$%s}){%s}}}""" % ( - project_id_param, tag_ids_param, project_id_param, tag_ids_param, - tag_ids_param, query.results_query_part(ResourceTag)) + project_id_param, + tag_ids_param, + project_id_param, + tag_ids_param, + tag_ids_param, + query.results_query_part(ResourceTag), + ) - res = self.client.execute(query_str, { - project_id_param: self.uid, - tag_ids_param: resource_tag_ids - }) + res = self.client.execute( + query_str, + {project_id_param: self.uid, tag_ids_param: resource_tag_ids}, + ) return [ ResourceTag(self.client, tag) @@ -274,13 +327,14 @@ def get_resource_tags(self) -> List[ResourceTag]: } }""" % (query.results_query_part(ResourceTag)) - results = self.client.execute( - query_str, {"projectId": self.uid})['project']['resourceTags'] + results = self.client.execute(query_str, {"projectId": self.uid})[ + "project" + ]["resourceTags"] return [ResourceTag(self.client, tag) for tag in results] def labels(self, datasets=None, order_by=None) -> PaginatedCollection: - """ Custom relationship expansion method to support limited filtering. + """Custom relationship expansion method to support limited filtering. 
Args: datasets (iterable of Dataset): Optional collection of Datasets @@ -292,14 +346,17 @@ def labels(self, datasets=None, order_by=None) -> PaginatedCollection: if datasets is not None: where = " where:{dataRow: {dataset: {id_in: [%s]}}}" % ", ".join( - '"%s"' % dataset.uid for dataset in datasets) + '"%s"' % dataset.uid for dataset in datasets + ) else: where = "" if order_by is not None: query.check_order_by_clause(Label, order_by) - order_by_str = "orderBy: %s_%s" % (order_by[0].graphql_name, - order_by[1].name.upper()) + order_by_str = "orderBy: %s_%s" % ( + order_by[0].graphql_name, + order_by[1].name.upper(), + ) else: order_by_str = "" @@ -307,17 +364,25 @@ def labels(self, datasets=None, order_by=None) -> PaginatedCollection: query_str = """query GetProjectLabelsPyApi($%s: ID!) {project (where: {id: $%s}) {labels (skip: %%d first: %%d %s %s) {%s}}}""" % ( - id_param, id_param, where, order_by_str, - query.results_query_part(Label)) + id_param, + id_param, + where, + order_by_str, + query.results_query_part(Label), + ) - return PaginatedCollection(self.client, query_str, {id_param: self.uid}, - ["project", "labels"], Label) + return PaginatedCollection( + self.client, + query_str, + {id_param: self.uid}, + ["project", "labels"], + Label, + ) def export_queued_data_rows( - self, - timeout_seconds=120, - include_metadata: bool = False) -> List[Dict[str, str]]: - """ Returns all data rows that are currently enqueued for this project. + self, timeout_seconds=120, include_metadata: bool = False + ) -> List[Dict[str, str]]: + """Returns all data rows that are currently enqueued for this project. Args: timeout_seconds (float): Max waiting time, in seconds. @@ -329,7 +394,8 @@ def export_queued_data_rows( """ warnings.warn( "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) + DeprecationWarning, + ) id_param = "projectId" metadata_param = "includeMetadataInput" query_str = """mutation GetQueuedDataRowsExportUrlPyApi($%s: ID!, $%s: Boolean!) @@ -338,10 +404,10 @@ def export_queued_data_rows( sleep_time = 2 start_time = time.time() while True: - res = self.client.execute(query_str, { - id_param: self.uid, - metadata_param: include_metadata - }) + res = self.client.execute( + query_str, + {id_param: self.uid, metadata_param: include_metadata}, + ) res = res["exportQueuedDataRows"] if res["status"] == "COMPLETE": download_url = res["downloadUrl"] @@ -359,14 +425,14 @@ def export_queued_data_rows( logger.debug( "Project '%s' queued data row export, waiting for server...", - self.uid) + self.uid, + ) time.sleep(sleep_time) - def export_labels(self, - download=False, - timeout_seconds=1800, - **kwargs) -> Optional[Union[str, List[Dict[Any, Any]]]]: - """ Calls the server-side Label exporting that generates a JSON + def export_labels( + self, download=False, timeout_seconds=1800, **kwargs + ) -> Optional[Union[str, List[Dict[Any, Any]]]]: + """Calls the server-side Label exporting that generates a JSON payload, and returns the URL to that payload. Will only generate a new URL at a max frequency of 30 min. @@ -389,7 +455,8 @@ def export_labels(self, """ warnings.warn( "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. 
To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) + DeprecationWarning, + ) def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str: """Returns a concatenated string of the dictionary's keys and values @@ -397,12 +464,14 @@ def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str: The string will be formatted as {key}: 'value' for each key. Value will be inclusive of quotations while key will not. This can be toggled with `value_with_quotes`""" - quote = "\"" if value_with_quotes else "" - return ",".join([ - f"""{c}: {quote}{dictionary.get(c)}{quote}""" - for c in dictionary - if dictionary.get(c) - ]) + quote = '"' if value_with_quotes else "" + return ",".join( + [ + f"""{c}: {quote}{dictionary.get(c)}{quote}""" + for c in dictionary + if dictionary.get(c) + ] + ) sleep_time = 2 id_param = "projectId" @@ -412,15 +481,16 @@ def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str: if "start" in kwargs or "end" in kwargs: created_at_dict = { "start": kwargs.get("start", ""), - "end": kwargs.get("end", "") + "end": kwargs.get("end", ""), } [validate_datetime(date) for date in created_at_dict.values()] filter_param_dict["labelCreatedAt"] = "{%s}" % _string_from_dict( - created_at_dict, value_with_quotes=True) + created_at_dict, value_with_quotes=True + ) if "last_activity_start" in kwargs or "last_activity_end" in kwargs: - last_activity_start = kwargs.get('last_activity_start') - last_activity_end = kwargs.get('last_activity_end') + last_activity_start = kwargs.get("last_activity_start") + last_activity_end = kwargs.get("last_activity_end") if last_activity_start: validate_datetime(str(last_activity_start)) @@ -428,15 +498,14 @@ def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str: validate_datetime(str(last_activity_end)) filter_param_dict["lastActivityAt"] = "{%s}" % _string_from_dict( - { - "start": last_activity_start, - "end": last_activity_end - }, - value_with_quotes=True) + {"start": last_activity_start, "end": last_activity_end}, + value_with_quotes=True, + ) if filter_param_dict: - filter_param = """, filters: {%s }""" % (_string_from_dict( - filter_param_dict, value_with_quotes=False)) + filter_param = """, filters: {%s }""" % ( + _string_from_dict(filter_param_dict, value_with_quotes=False) + ) query_str = """mutation GetLabelExportUrlPyApi($%s: ID!) 
{exportLabels(data:{projectId: $%s%s}) {downloadUrl createdAt shouldPoll} } @@ -448,7 +517,7 @@ def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str: res = self.client.execute(query_str, {id_param: self.uid}) res = res["exportLabels"] if not res["shouldPoll"] and res["downloadUrl"] is not None: - url = res['downloadUrl'] + url = res["downloadUrl"] if not download: return url else: @@ -460,8 +529,9 @@ def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str: if current_time - start_time > timeout_seconds: return None - logger.debug("Project '%s' label export, waiting for server...", - self.uid) + logger.debug( + "Project '%s' label export, waiting for server...", self.uid + ) time.sleep(sleep_time) def export( @@ -516,7 +586,7 @@ def export_v2( >>> task.result """ task, is_streamable = self._export(task_name, filters, params) - if (is_streamable): + if is_streamable: return ExportTask(task, True) return task @@ -527,34 +597,39 @@ def _export( params: Optional[ProjectExportParams] = None, streamable: bool = False, ) -> Tuple[Task, bool]: - _params = params or ProjectExportParams({ - "attachments": False, - "embeddings": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "interpolated_frames": False, - }) - - _filters = filters or ProjectExportFilters({ - "last_activity_at": None, - "label_created_at": None, - "data_row_ids": None, - "global_keys": None, - "batch_ids": None, - "workflow_status": None - }) + _params = params or ProjectExportParams( + { + "attachments": False, + "embeddings": False, + "metadata_fields": False, + "data_row_details": False, + "project_details": False, + "performance_details": False, + "label_details": False, + "media_type_override": None, + "interpolated_frames": False, + } + ) + + _filters = filters or ProjectExportFilters( + { + "last_activity_at": None, + "label_created_at": None, + "data_row_ids": None, + "global_keys": None, + "batch_ids": None, + "workflow_status": None, + } + ) mutation_name = "exportDataRowsInProject" create_task_query_str = ( f"mutation {mutation_name}PyApi" f"($input: ExportDataRowsInProjectInput!)" - f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}") + f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}" + ) - media_type_override = _params.get('media_type_override', None) + media_type_override = _params.get("media_type_override", None) query_params: Dict[str, Any] = { "input": { "taskName": task_name, @@ -564,28 +639,28 @@ def _export( "searchQuery": { "scope": None, "query": [], - } + }, }, "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeEmbeddings": - _params.get('embeddings', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), + "mediaTypeOverride": media_type_override.value + if media_type_override is not None + else None, + "includeAttachments": _params.get("attachments", False), + "includeEmbeddings": _params.get("embeddings", False), + 
"includeMetadata": _params.get("metadata_fields", False), + "includeDataRowDetails": _params.get( + "data_row_details", False + ), + "includeProjectDetails": _params.get( + "project_details", False + ), + "includePerformanceDetails": _params.get( + "performance_details", False + ), + "includeLabelDetails": _params.get("label_details", False), + "includeInterpolatedFrames": _params.get( + "interpolated_frames", False + ), }, "streamable": streamable, } @@ -594,16 +669,16 @@ def _export( search_query = build_filters(self.client, _filters) query_params["input"]["filters"]["searchQuery"]["query"] = search_query - res = self.client.execute(create_task_query_str, - query_params, - error_log_key="errors") + res = self.client.execute( + create_task_query_str, query_params, error_log_key="errors" + ) res = res[mutation_name] task_id = res["taskId"] is_streamable = res["isStreamable"] return Task.get_task(self.client, task_id), is_streamable def export_issues(self, status=None) -> str: - """ Calls the server-side Issues exporting that + """Calls the server-side Issues exporting that returns the URL to that payload. Args: @@ -622,19 +697,19 @@ def export_issues(self, status=None) -> str: valid_statuses = {None, "Open", "Resolved"} if status not in valid_statuses: - raise ValueError("status must be in {}. Found {}".format( - valid_statuses, status)) + raise ValueError( + "status must be in {}. Found {}".format(valid_statuses, status) + ) - res = self.client.execute(query_str, { - id_param: self.uid, - status_param: status - }) + res = self.client.execute( + query_str, {id_param: self.uid, status_param: status} + ) - res = res['project'] + res = res["project"] logger.debug("Project '%s' issues export, link generated", self.uid) - return res.get('issueExportUrl') + return res.get("issueExportUrl") def upsert_instructions(self, instructions_file: str) -> None: """ @@ -660,7 +735,8 @@ def upsert_instructions(self, instructions_file: str) -> None: if frontend.name != "Editor": logger.warning( f"This function has only been tested to work with the Editor front end. Found %s", - frontend.name) + frontend.name, + ) supported_instruction_formats = (".pdf", ".html") if not instructions_file.endswith(supported_instruction_formats): @@ -683,13 +759,13 @@ def upsert_instructions(self, instructions_file: str) -> None: } }""" - self.client.execute(query_str, { - 'projectId': self.uid, - 'instructions_url': instructions_url - }) + self.client.execute( + query_str, + {"projectId": self.uid, "instructions_url": instructions_url}, + ) def labeler_performance(self) -> PaginatedCollection: - """ Returns the labeler performances for this Project. + """Returns the labeler performances for this Project. Returns: A PaginatedCollection of LabelerPerformance objects. 
@@ -706,17 +782,25 @@ def create_labeler_performance(client, result): result["user"] = Entity.User(client, result["user"]) # python isoformat doesn't accept Z as utc timezone result["lastActivityTime"] = utils.format_iso_from_string( - result["lastActivityTime"].replace('Z', '+00:00')) - return LabelerPerformance(**{ - utils.snake_case(key): value for key, value in result.items() - }) + result["lastActivityTime"].replace("Z", "+00:00") + ) + return LabelerPerformance( + **{ + utils.snake_case(key): value + for key, value in result.items() + } + ) - return PaginatedCollection(self.client, query_str, {id_param: self.uid}, - ["project", "labelerPerformance"], - create_labeler_performance) + return PaginatedCollection( + self.client, + query_str, + {id_param: self.uid}, + ["project", "labelerPerformance"], + create_labeler_performance, + ) def review_metrics(self, net_score) -> int: - """ Returns this Project's review metrics. + """Returns this Project's review metrics. Args: net_score (None or Review.NetScore): Indicates desired metric. @@ -726,7 +810,8 @@ def review_metrics(self, net_score) -> int: if net_score not in (None,) + tuple(Entity.Review.NetScore): raise InvalidQueryError( "Review metrics net score must be either None " - "or one of Review.NetScore values") + "or one of Review.NetScore values" + ) id_param = "projectId" net_score_literal = "None" if net_score is None else net_score.name query_str = """query ProjectReviewMetricsPyApi($%s: ID!){ @@ -758,24 +843,23 @@ def connect_ontology(self, ontology) -> None: if not self.is_empty_ontology(): raise ValueError("Ontology already connected to project.") - if self.labeling_frontend( - ) is None: # Chat evaluation projects are automatically set up via the same api that creates a project - self._connect_default_labeling_front_end(ontology_as_dict={ - "tools": [], - "classifications": [] - }) + if ( + self.labeling_frontend() is None + ): # Chat evaluation projects are automatically set up via the same api that creates a project + self._connect_default_labeling_front_end( + ontology_as_dict={"tools": [], "classifications": []} + ) query_str = """mutation ConnectOntologyPyApi($projectId: ID!, $ontologyId: ID!){ project(where: {id: $projectId}) {connectOntology(ontologyId: $ontologyId) {id}}}""" - self.client.execute(query_str, { - 'ontologyId': ontology.uid, - 'projectId': self.uid - }) + self.client.execute( + query_str, {"ontologyId": ontology.uid, "projectId": self.uid} + ) timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") self.update(setup_complete=timestamp) def setup(self, labeling_frontend, labeling_frontend_options) -> None: - """ This method will associate default labeling frontend with the project and create an ontology based on labeling_frontend_options. + """This method will associate default labeling frontend with the project and create an ontology based on labeling_frontend_options. Args: labeling_frontend (LabelingFrontend): Do not use, this parameter is deprecated. We now associate the default labeling frontend with the project. 
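# Illustrative usage sketch (not part of the patch): connect_ontology(),
# described above, attaches an existing ontology to the project and records
# setup_complete; it raises ValueError if the project already has a non-empty
# ontology. All identifiers below are placeholders.
from labelbox import Client

client = Client(api_key="<YOUR_API_KEY>")
project = client.get_project("<PROJECT_ID>")
ontology = client.get_ontology("<ONTOLOGY_ID>")

if project.is_empty_ontology():
    project.connect_ontology(ontology)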
@@ -804,11 +888,15 @@ def setup(self, labeling_frontend, labeling_frontend_options) -> None: def _connect_default_labeling_front_end(self, ontology_as_dict: dict): labeling_frontend = self.labeling_frontend() - if labeling_frontend is None: # Chat evaluation projects are automatically set up via the same api that creates a project + if ( + labeling_frontend is None + ): # Chat evaluation projects are automatically set up via the same api that creates a project warnings.warn("Connecting default labeling editor for the project.") labeling_frontend = next( self.client.get_labeling_frontends( - where=Entity.LabelingFrontend.name == "Editor")) + where=Entity.LabelingFrontend.name == "Editor" + ) + ) self.labeling_frontend.connect(labeling_frontend) if not isinstance(ontology_as_dict, str): @@ -818,11 +906,13 @@ def _connect_default_labeling_front_end(self, ontology_as_dict: dict): LFO = Entity.LabelingFrontendOptions self.client._create( - LFO, { + LFO, + { LFO.project: self, LFO.labeling_frontend: labeling_frontend, - LFO.customization_options: labeling_frontend_options_str - }) + LFO.customization_options: labeling_frontend_options_str, + }, + ) def create_batch( self, @@ -855,7 +945,8 @@ def create_batch( if self.is_auto_data_generation(): raise ValueError( - "Cannot create batches for auto data generation projects") + "Cannot create batches for auto data generation projects" + ) dr_ids = [] if data_rows is not None: @@ -866,7 +957,8 @@ def create_batch( dr_ids.append(dr) else: raise ValueError( - "`data_rows` must be DataRow ids or DataRow objects") + "`data_rows` must be DataRow ids or DataRow objects" + ) if data_rows is not None: row_count = len(dr_ids) @@ -877,23 +969,28 @@ def create_batch( if row_count > 100_000: raise ValueError( - f"Batch exceeds max size, break into smaller batches") + f"Batch exceeds max size, break into smaller batches" + ) if not row_count: raise ValueError("You need at least one data row in a batch") self._wait_until_data_rows_are_processed( - dr_ids, global_keys, self._wait_processing_max_seconds) + dr_ids, global_keys, self._wait_processing_max_seconds + ) if consensus_settings: - consensus_settings = ConsensusSettings(**consensus_settings).model_dump( - by_alias=True) + consensus_settings = ConsensusSettings( + **consensus_settings + ).model_dump(by_alias=True) if row_count >= MAX_SYNC_BATCH_ROW_COUNT: - return self._create_batch_async(name, dr_ids, global_keys, priority, - consensus_settings) + return self._create_batch_async( + name, dr_ids, global_keys, priority, consensus_settings + ) else: - return self._create_batch_sync(name, dr_ids, global_keys, priority, - consensus_settings) + return self._create_batch_sync( + name, dr_ids, global_keys, priority, consensus_settings + ) def create_batches( self, @@ -936,16 +1033,19 @@ def create_batches( dr_ids.append(dr) else: raise ValueError( - "`data_rows` must be DataRow ids or DataRow objects") + "`data_rows` must be DataRow ids or DataRow objects" + ) self._wait_until_data_rows_are_processed( - dr_ids, global_keys, self._wait_processing_max_seconds) + dr_ids, global_keys, self._wait_processing_max_seconds + ) if consensus_settings: - consensus_settings = ConsensusSettings(**consensus_settings).model_dump( - by_alias=True) + consensus_settings = ConsensusSettings( + **consensus_settings + ).model_dump(by_alias=True) - method = 'createBatches' + method = "createBatches" mutation_str = """mutation %sPyApi($projectId: ID!, $input: CreateBatchesInput!) 
{ project(where: {id: $projectId}) { %s(input: $input) { @@ -965,12 +1065,13 @@ def create_batches( "dataRowIds": dr_ids, "globalKeys": global_keys, "priority": priority, - "consensusSettings": consensus_settings - } + "consensusSettings": consensus_settings, + }, } - tasks = self.client.execute( - mutation_str, params, experimental=True)["project"][method]["tasks"] + tasks = self.client.execute(mutation_str, params, experimental=True)[ + "project" + ][method]["tasks"] batch_ids = [task["batchUuid"] for task in tasks] task_ids = [task["taskId"] for task in tasks] @@ -981,8 +1082,8 @@ def create_batches_from_dataset( name_prefix: str, dataset_id: str, priority: int = 5, - consensus_settings: Optional[Dict[str, - Any]] = None) -> CreateBatchesTask: + consensus_settings: Optional[Dict[str, Any]] = None, + ) -> CreateBatchesTask: """ Creates batches for a project from a dataset, selecting only the data rows that are not already added to the project. When the dataset contains more than 100k data rows and multiple batches are needed, the specific batch @@ -1009,10 +1110,11 @@ def create_batches_from_dataset( raise ValueError("Project must be in batch mode") if consensus_settings: - consensus_settings = ConsensusSettings(**consensus_settings).model_dump( - by_alias=True) + consensus_settings = ConsensusSettings( + **consensus_settings + ).model_dump(by_alias=True) - method = 'createBatchesFromDataset' + method = "createBatchesFromDataset" mutation_str = """mutation %sPyApi($projectId: ID!, $input: CreateBatchesFromDatasetInput!) { project(where: {id: $projectId}) { %s(input: $input) { @@ -1031,21 +1133,23 @@ def create_batches_from_dataset( "batchNamePrefix": name_prefix, "datasetId": dataset_id, "priority": priority, - "consensusSettings": consensus_settings - } + "consensusSettings": consensus_settings, + }, } - tasks = self.client.execute( - mutation_str, params, experimental=True)["project"][method]["tasks"] + tasks = self.client.execute(mutation_str, params, experimental=True)[ + "project" + ][method]["tasks"] batch_ids = [task["batchUuid"] for task in tasks] task_ids = [task["taskId"] for task in tasks] return CreateBatchesTask(self.client, self.uid, batch_ids, task_ids) - def _create_batch_sync(self, name, dr_ids, global_keys, priority, - consensus_settings): - method = 'createBatchV2' + def _create_batch_sync( + self, name, dr_ids, global_keys, priority, consensus_settings + ): + method = "createBatchV2" query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) 
{ project(where: {id: $projectId}) { %s(input: $batchInput) { @@ -1064,28 +1168,30 @@ def _create_batch_sync(self, name, dr_ids, global_keys, priority, "dataRowIds": dr_ids, "globalKeys": global_keys, "priority": priority, - "consensusSettings": consensus_settings - } + "consensusSettings": consensus_settings, + }, } - res = self.client.execute(query_str, - params, - timeout=180.0, - experimental=True)["project"][method] - batch = res['batch'] - batch['size'] = res['batch']['size'] - return Entity.Batch(self.client, - self.uid, - batch, - failed_data_row_ids=res['failedDataRowIds']) - - def _create_batch_async(self, - name: str, - dr_ids: Optional[List[str]] = None, - global_keys: Optional[List[str]] = None, - priority: int = 5, - consensus_settings: Optional[Dict[str, - float]] = None): - method = 'createEmptyBatch' + res = self.client.execute( + query_str, params, timeout=180.0, experimental=True + )["project"][method] + batch = res["batch"] + batch["size"] = res["batch"]["size"] + return Entity.Batch( + self.client, + self.uid, + batch, + failed_data_row_ids=res["failedDataRowIds"], + ) + + def _create_batch_async( + self, + name: str, + dr_ids: Optional[List[str]] = None, + global_keys: Optional[List[str]] = None, + priority: int = 5, + consensus_settings: Optional[Dict[str, float]] = None, + ): + method = "createEmptyBatch" create_empty_batch_mutation_str = """mutation %sPyApi($projectId: ID!, $input: CreateEmptyBatchInput!) { project(where: {id: $projectId}) { %s(input: $input) { @@ -1097,19 +1203,18 @@ def _create_batch_async(self, params = { "projectId": self.uid, - "input": { - "name": name, - "consensusSettings": consensus_settings - } + "input": {"name": name, "consensusSettings": consensus_settings}, } - res = self.client.execute(create_empty_batch_mutation_str, - params, - timeout=180.0, - experimental=True)["project"][method] - batch_id = res['id'] + res = self.client.execute( + create_empty_batch_mutation_str, + params, + timeout=180.0, + experimental=True, + )["project"][method] + batch_id = res["id"] - method = 'addDataRowsToBatchAsync' + method = "addDataRowsToBatchAsync" add_data_rows_mutation_str = """mutation %sPyApi($projectId: ID!, $input: AddDataRowsToBatchInput!) { project(where: {id: $projectId}) { %s(input: $input) { @@ -1126,20 +1231,21 @@ def _create_batch_async(self, "dataRowIds": dr_ids, "globalKeys": global_keys, "priority": priority, - } + }, } - res = self.client.execute(add_data_rows_mutation_str, - params, - timeout=180.0, - experimental=True)["project"][method] + res = self.client.execute( + add_data_rows_mutation_str, params, timeout=180.0, experimental=True + )["project"][method] - task_id = res['taskId'] + task_id = res["taskId"] task = self._wait_for_task(task_id) if task.status != "COMPLETE": - raise LabelboxError(f"Batch was not created successfully: " + - json.dumps(task.errors)) + raise LabelboxError( + f"Batch was not created successfully: " + + json.dumps(task.errors) + ) return self.client.get_batch(self.uid, batch_id) @@ -1173,21 +1279,24 @@ def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode": status = "DISABLED" else: raise ValueError( - "Must provide either `BATCH` or `DATASET` as a mode") + "Must provide either `BATCH` or `DATASET` as a mode" + ) - query_str = """mutation %s($projectId: ID!, $status: TagSetStatusInput!) { + query_str = ( + """mutation %s($projectId: ID!, $status: TagSetStatusInput!) 
{ project(where: {id: $projectId}) { setTagSetStatus(input: {tagSetStatus: $status}) { tagSetStatus } } } - """ % "setTagSetStatusPyApi" + """ + % "setTagSetStatusPyApi" + ) - self.client.execute(query_str, { - 'projectId': self.uid, - 'status': status - }) + self.client.execute( + query_str, {"projectId": self.uid, "status": status} + ) return mode @@ -1202,7 +1311,7 @@ def get_label_count(self) -> int: } }""" - res = self.client.execute(query_str, {'projectId': self.uid}) + res = self.client.execute(query_str, {"projectId": self.uid}) return res["project"]["labelCount"] def get_queue_mode(self) -> "QueueMode": @@ -1221,17 +1330,22 @@ def get_queue_mode(self) -> "QueueMode": logger.warning( "Obtaining the queue_mode for a project through this method will soon" - " no longer be supported.") + " no longer be supported." + ) - query_str = """query %s($projectId: ID!) { + query_str = ( + """query %s($projectId: ID!) { project(where: {id: $projectId}) { tagSetStatus } } - """ % "GetTagSetStatusPyApi" + """ + % "GetTagSetStatusPyApi" + ) - status = self.client.execute( - query_str, {'projectId': self.uid})["project"]["tagSetStatus"] + status = self.client.execute(query_str, {"projectId": self.uid})[ + "project" + ]["tagSetStatus"] if status == "ENABLED": return QueueMode.Batch @@ -1241,7 +1355,7 @@ def get_queue_mode(self) -> "QueueMode": raise ValueError("Status not known") def add_model_config(self, model_config_id: str) -> str: - """ Adds a model config to this project. + """Adds a model config to this project. Args: model_config_id (str): ID of a model config to add to this project. @@ -1264,10 +1378,11 @@ def add_model_config(self, model_config_id: str) -> str: result = self.client.execute(query, params) except LabelboxError as e: if e.message.startswith( - "Unknown error: " + "Unknown error: " ): # unfortunate hack to handle unparsed graphql errors error_content = error_message_for_unparsed_graphql_error( - e.message) + e.message + ) else: error_content = e.message raise LabelboxError(message=error_content) from e @@ -1277,7 +1392,7 @@ def add_model_config(self, model_config_id: str) -> str: return result["createProjectModelConfig"]["projectModelConfigId"] def delete_project_model_config(self, project_model_config_id: str) -> bool: - """ Deletes the association between a model config and this project. + """Deletes the association between a model config and this project. Args: project_model_config_id (str): ID of a project model config association to delete for this project. @@ -1319,12 +1434,14 @@ def set_project_model_setup_complete(self) -> bool: result = self.client.execute(query, {"projectId": self.uid}) self.model_setup_complete = result["setProjectModelSetupComplete"][ - "modelSetupComplete"] + "modelSetupComplete" + ] return result["setProjectModelSetupComplete"]["modelSetupComplete"] def set_labeling_parameter_overrides( - self, data: List[LabelingParameterOverrideInput]) -> bool: - """ Adds labeling parameter overrides to this project. + self, data: List[LabelingParameterOverrideInput] + ) -> bool: + """Adds labeling parameter overrides to this project. 
See information on priority here: https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system >>> project.set_labeling_parameter_overrides([ >>> (data_row_id1, 2), (data_row_id2, 1)]) or >>> project.set_labeling_parameter_overrides([ >>> (data_row_gk1, 2), (data_row_gk2, 1)]) @@ -1364,22 +1481,25 @@ def set_labeling_parameter_overrides( {setLabelingParameterOverrides (dataWithDataRowIdentifiers: [$dataWithDataRowIdentifiers]) {success}}} - """) + """
+ ) data_rows_with_identifiers = "" for data_row, priority in data: if isinstance(data_row, DataRow): - data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.uid}\", idType: {IdType.DataRowId}}}, priority: {priority}}}," + data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.uid}", idType: {IdType.DataRowId}}}, priority: {priority}}},' elif isinstance(data_row, UniqueId) or isinstance( - data_row, GlobalKey): - data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.key}\", idType: {data_row.id_type}}}, priority: {priority}}}," + data_row, GlobalKey
+ ):
+ data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.key}", idType: {data_row.id_type}}}, priority: {priority}}},' else: raise TypeError( f"Data row identifier should be of type DataRow or Data Row Identifier. Found {type(data_row)}." ) query_str = template.substitute( - dataWithDataRowIdentifiers=data_rows_with_identifiers) + dataWithDataRowIdentifiers=data_rows_with_identifiers
+ ) res = self.client.execute(query_str, {"projectId": self.uid}) return res["project"]["setLabelingParameterOverrides"]["success"] @@ -1422,8 +1542,10 @@ def update_data_row_labeling_priority( if isinstance(data_rows, list): data_rows = UniqueIds(data_rows) - warnings.warn("Using data row ids will be deprecated. Please use " - "UniqueIds or GlobalKeys instead.") + warnings.warn(
+ "Using data row ids will be deprecated. Please use "
+ "UniqueIds or GlobalKeys instead."
+ ) method = "createQueuePriorityUpdateTask" priority_param = "priority" @@ -1442,28 +1564,40 @@ def update_data_row_labeling_priority( } } } - """ % (method, priority_param, project_param, data_rows_param, - project_param, method, priority_param, data_rows_param) + """ % (
+ method,
+ priority_param,
+ project_param,
+ data_rows_param,
+ project_param,
+ method,
+ priority_param,
+ data_rows_param,
+ ) res = self.client.execute( - query_str, { + query_str,
+ { priority_param: priority, project_param: self.uid, data_rows_param: { "ids": [id for id in data_rows], "idType": data_rows.id_type, }, - })["project"][method] + },
+ )["project"][method] - task_id = res['taskId'] + task_id = res["taskId"] task = self._wait_for_task(task_id) if task.status != "COMPLETE": - raise LabelboxError(f"Priority was not updated successfully: " + - json.dumps(task.errors)) + raise LabelboxError(
+ f"Priority was not updated successfully: "
+ + json.dumps(task.errors)
+ ) return True def extend_reservations(self, queue_type) -> int: - """ Extends all the current reservations for the current user on the given + """Extends all the current reservations for the current user on the given queue type. 
Args: queue_type (str): Either "LabelingQueue" or "ReviewQueue" @@ -1476,12 +1610,15 @@ def extend_reservations(self, queue_type) -> int: id_param = "projectId" query_str = """mutation ExtendReservationsPyApi($%s: ID!){ extendReservations(projectId:$%s queueType:%s)}""" % ( - id_param, id_param, queue_type) + id_param, + id_param, + queue_type, + ) res = self.client.execute(query_str, {id_param: self.uid}) return res["extendReservations"] def enable_model_assisted_labeling(self, toggle: bool = True) -> bool: - """ Turns model assisted labeling either on or off based on input + """Turns model assisted labeling either on or off based on input Args: toggle (bool): True or False boolean @@ -1503,10 +1640,11 @@ def enable_model_assisted_labeling(self, toggle: bool = True) -> bool: res = self.client.execute(query_str, params) return res["project"]["showPredictionsToLabelers"][ - "showingPredictionsToLabelers"] + "showingPredictionsToLabelers" + ] def bulk_import_requests(self) -> PaginatedCollection: - """ Returns bulk import request objects which are used in model-assisted labeling. + """Returns bulk import request objects which are used in model-assisted labeling. These are returned with the oldest first, and most recent last. """ @@ -1519,15 +1657,21 @@ def bulk_import_requests(self) -> PaginatedCollection: ) { %s } - }""" % (id_param, id_param, - query.results_query_part(Entity.BulkImportRequest)) - return PaginatedCollection(self.client, query_str, - {id_param: str(self.uid)}, - ["bulkImportRequests"], - Entity.BulkImportRequest) + }""" % ( + id_param, + id_param, + query.results_query_part(Entity.BulkImportRequest), + ) + return PaginatedCollection( + self.client, + query_str, + {id_param: str(self.uid)}, + ["bulkImportRequests"], + Entity.BulkImportRequest, + ) def batches(self) -> PaginatedCollection: - """ Fetch all batches that belong to this project + """Fetch all batches that belong to this project Returns: A `PaginatedCollection` of `Batch`es @@ -1539,13 +1683,16 @@ def batches(self) -> PaginatedCollection: """ % (id_param, id_param, query.results_query_part(Entity.Batch)) return PaginatedCollection( self.client, - query_str, {id_param: self.uid}, ['project', 'batches', 'nodes'], + query_str, + {id_param: self.uid}, + ["project", "batches", "nodes"], lambda client, res: Entity.Batch(client, self.uid, res), - cursor_path=['project', 'batches', 'pageInfo', 'endCursor'], - experimental=True) + cursor_path=["project", "batches", "pageInfo", "endCursor"], + experimental=True, + ) def task_queues(self) -> List[TaskQueue]: - """ Fetch all task queues that belong to this project + """Fetch all task queues that belong to this project Returns: A `List` of `TaskQueue`s @@ -1560,9 +1707,8 @@ def task_queues(self) -> List[TaskQueue]: """ % (query.results_query_part(Entity.TaskQueue)) task_queue_values = self.client.execute( - query_str, {"projectId": self.uid}, - timeout=180.0, - experimental=True)["project"]["taskQueues"] + query_str, {"projectId": self.uid}, timeout=180.0, experimental=True + )["project"]["taskQueues"] return [ Entity.TaskQueue(self.client, field_values) @@ -1570,13 +1716,15 @@ def task_queues(self) -> List[TaskQueue]: ] @overload - def move_data_rows_to_task_queue(self, data_row_ids: DataRowIdentifiers, - task_queue_id: str): + def move_data_rows_to_task_queue( + self, data_row_ids: DataRowIdentifiers, task_queue_id: str + ): pass @overload - def move_data_rows_to_task_queue(self, data_row_ids: List[str], - task_queue_id: str): + def move_data_rows_to_task_queue( + self, 
data_row_ids: List[str], task_queue_id: str + ): pass def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str): @@ -1595,11 +1743,14 @@ def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str): """ if isinstance(data_row_ids, list): data_row_ids = UniqueIds(data_row_ids) - warnings.warn("Using data row ids will be deprecated. Please use " - "UniqueIds or GlobalKeys instead.") + warnings.warn( + "Using data row ids will be deprecated. Please use " + "UniqueIds or GlobalKeys instead." + ) method = "createBulkAddRowsToQueueTask" - query_str = """mutation AddDataRowsToTaskQueueAsyncPyApi( + query_str = ( + """mutation AddDataRowsToTaskQueueAsyncPyApi( $projectId: ID! $queueId: ID $dataRowIdentifiers: AddRowsToTaskQueueViaDataRowIdentifiersInput! @@ -1612,10 +1763,13 @@ def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str): } } } - """ % method + """ + % method + ) task_id = self.client.execute( - query_str, { + query_str, + { "projectId": self.uid, "queueId": task_queue_id, "dataRowIdentifiers": { @@ -1624,12 +1778,15 @@ def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str): }, }, timeout=180.0, - experimental=True)["project"][method]["taskId"] + experimental=True, + )["project"][method]["taskId"] task = self._wait_for_task(task_id) if task.status != "COMPLETE": - raise LabelboxError(f"Data rows were not moved successfully: " + - json.dumps(task.errors)) + raise LabelboxError( + f"Data rows were not moved successfully: " + + json.dumps(task.errors) + ) def _wait_for_task(self, task_id: str) -> Task: task = Task.get_task(self.client, task_id) @@ -1638,11 +1795,12 @@ def _wait_for_task(self, task_id: str) -> Task: return task def upload_annotations( - self, - name: str, - annotations: Union[str, Path, Iterable[Dict]], - validate: bool = False) -> 'BulkImportRequest': # type: ignore - """ Uploads annotations to a new Editor project. + self, + name: str, + annotations: Union[str, Path, Iterable[Dict]], + validate: bool = False, + ) -> "BulkImportRequest": # type: ignore + """Uploads annotations to a new Editor project. Args: name (str): name of the BulkImportRequest job @@ -1660,7 +1818,7 @@ def upload_annotations( if isinstance(annotations, str) or isinstance(annotations, Path): def _is_url_valid(url: Union[str, Path]) -> bool: - """ Verifies that the given string is a valid url. + """Verifies that the given string is a valid url. 
Args: url: string to be checked @@ -1679,12 +1837,13 @@ def _is_url_valid(url: Union[str, Path]) -> bool: project_id=self.uid, name=name, url=str(annotations), - validate=validate) + validate=validate, + ) else: path = Path(annotations) if not path.exists(): raise FileNotFoundError( - f'{annotations} is not a valid url nor existing local file' + f"{annotations} is not a valid url nor existing local file" ) return Entity.BulkImportRequest.create_from_local_file( client=self.client, @@ -1699,64 +1858,79 @@ def _is_url_valid(url: Union[str, Path]) -> bool: project_id=self.uid, name=name, predictions=annotations, # type: ignore - validate=validate) + validate=validate, + ) else: raise ValueError( - f'Invalid annotations given of type: {type(annotations)}') + f"Invalid annotations given of type: {type(annotations)}" + ) def _wait_until_data_rows_are_processed( - self, - data_row_ids: Optional[List[str]] = None, - global_keys: Optional[List[str]] = None, - wait_processing_max_seconds: int = _wait_processing_max_seconds, - sleep_interval=30): - """ Wait until all the specified data rows are processed""" + self, + data_row_ids: Optional[List[str]] = None, + global_keys: Optional[List[str]] = None, + wait_processing_max_seconds: int = _wait_processing_max_seconds, + sleep_interval=30, + ): + """Wait until all the specified data rows are processed""" start_time = datetime.now() max_data_rows_per_poll = 100_000 if data_row_ids is not None: for i in range(0, len(data_row_ids), max_data_rows_per_poll): - chunk = data_row_ids[i:i + max_data_rows_per_poll] + chunk = data_row_ids[i : i + max_data_rows_per_poll] self._poll_data_row_processing_status( - chunk, [], start_time, wait_processing_max_seconds, - sleep_interval) + chunk, + [], + start_time, + wait_processing_max_seconds, + sleep_interval, + ) if global_keys is not None: for i in range(0, len(global_keys), max_data_rows_per_poll): - chunk = global_keys[i:i + max_data_rows_per_poll] + chunk = global_keys[i : i + max_data_rows_per_poll] self._poll_data_row_processing_status( - [], chunk, start_time, wait_processing_max_seconds, - sleep_interval) + [], + chunk, + start_time, + wait_processing_max_seconds, + sleep_interval, + ) def _poll_data_row_processing_status( - self, - data_row_ids: List[str], - global_keys: List[str], - start_time: datetime, - wait_processing_max_seconds: int = _wait_processing_max_seconds, - sleep_interval=30): - + self, + data_row_ids: List[str], + global_keys: List[str], + start_time: datetime, + wait_processing_max_seconds: int = _wait_processing_max_seconds, + sleep_interval=30, + ): while True: - if (datetime.now() - - start_time).total_seconds() >= wait_processing_max_seconds: + if ( + datetime.now() - start_time + ).total_seconds() >= wait_processing_max_seconds: raise ProcessingWaitTimeout( """Maximum wait time exceeded while waiting for data rows to be processed. - Try creating a batch a bit later""") + Try creating a batch a bit later""" + ) all_good = self.__check_data_rows_have_been_processed( - data_row_ids, global_keys) + data_row_ids, global_keys + ) if all_good: return logger.debug( - 'Some of the data rows are still being processed, waiting...') + "Some of the data rows are still being processed, waiting..." 
+ ) time.sleep(sleep_interval) def __check_data_rows_have_been_processed( - self, - data_row_ids: Optional[List[str]] = None, - global_keys: Optional[List[str]] = None): - + self, + data_row_ids: Optional[List[str]] = None, + global_keys: Optional[List[str]] = None, + ): if data_row_ids is not None and len(data_row_ids) > 0: param_name = "dataRowIds" params = {param_name: data_row_ids} @@ -1773,11 +1947,12 @@ def __check_data_rows_have_been_processed( response = self.client.execute(query_str, params) return response["queryAllDataRowsHaveBeenProcessed"][ - "allDataRowsHaveBeenProcessed"] + "allDataRowsHaveBeenProcessed" + ] def get_overview( - self, - details=False) -> Union[ProjectOverview, ProjectOverviewDetailed]: + self, details=False + ) -> Union[ProjectOverview, ProjectOverviewDetailed]: """Return the overview of a project. This method returns the number of data rows per task queue and issues of a project, @@ -1816,8 +1991,9 @@ def get_overview( """ # Must use experimental to access "issues" - result = self.client.execute(query, {"projectId": self.uid}, - experimental=True)["project"] + result = self.client.execute( + query, {"projectId": self.uid}, experimental=True + )["project"] # Reformat category names overview = { @@ -1838,16 +2014,14 @@ def get_overview( # Build dictionary for queue details for review and rework queues for category in ["rework", "review"]: queues = [ - { - tq["name"]: tq.get("dataRowCount") - } + {tq["name"]: tq.get("dataRowCount")} for tq in result.get("taskQueues") if tq.get("queueType") == f"MANUAL_{category.upper()}_QUEUE" ] overview[f"in_{category}"] = { "data": queues, - "total": overview[f"in_{category}"] + "total": overview[f"in_{category}"], } return ProjectOverviewDetailed(**overview) @@ -1897,7 +2071,7 @@ def get_labeling_service_dashboard(self) -> LabelingServiceDashboard: """Get the labeling service for this project. Returns: - LabelingServiceDashboard: The labeling service for this project. + LabelingServiceDashboard: The labeling service for this project. Attributes of the dashboard include: id (str): The project id. @@ -1927,12 +2101,13 @@ class ProjectMember(DbObject): class LabelingParameterOverride(DbObject): - """ Customizes the order of assets in the label queue. + """Customizes the order of assets in the label queue. Attributes: priority (int): A prioritization score. number_of_labels (int): Number of times an asset should be labeled. """ + priority = Field.Int("priority") number_of_labels = Field.Int("number_of_labels") @@ -1940,8 +2115,10 @@ class LabelingParameterOverride(DbObject): LabelerPerformance = namedtuple( - "LabelerPerformance", "user count seconds_per_label, total_time_labeling " - "consensus average_benchmark_agreement last_activity_time") + "LabelerPerformance", + "user count seconds_per_label, total_time_labeling " + "consensus average_benchmark_agreement last_activity_time", +) LabelerPerformance.__doc__ = ( - "Named tuple containing info about a labeler's performance.") - + "Named tuple containing info about a labeler's performance." 
+) diff --git a/libs/labelbox/src/labelbox/schema/project_model_config.py b/libs/labelbox/src/labelbox/schema/project_model_config.py index 9cf6dcbfa..9b6d8a0bb 100644 --- a/libs/labelbox/src/labelbox/schema/project_model_config.py +++ b/libs/labelbox/src/labelbox/schema/project_model_config.py @@ -1,12 +1,15 @@ from labelbox.orm.db_object import DbObject from labelbox.orm.model import Field, Relationship -from labelbox.exceptions import LabelboxError, error_message_for_unparsed_graphql_error +from labelbox.exceptions import ( + LabelboxError, + error_message_for_unparsed_graphql_error, +) class ProjectModelConfig(DbObject): - """ A ProjectModelConfig represents an association between a project and a single model config. + """A ProjectModelConfig represents an association between a project and a single model config. - Attributes: + Attributes: project_id (str): ID of project to associate model_config_id (str): ID of the model configuration model_config (ModelConfig): Configuration for model @@ -17,7 +20,7 @@ class ProjectModelConfig(DbObject): model_config = Relationship.ToOne("ModelConfig", False, "model_config") def delete(self) -> bool: - """ Deletes this association between a model config and this project. + """Deletes this association between a model config and this project. Returns: bool, indicates if the operation was a success. @@ -36,10 +39,11 @@ def delete(self) -> bool: result = self.client.execute(query, params) except LabelboxError as e: if e.message.startswith( - "Unknown error: " + "Unknown error: " ): # unfortunate hack to handle unparsed graphql errors error_content = error_message_for_unparsed_graphql_error( - e.message) + e.message + ) else: error_content = e.message raise LabelboxError(message=error_content) from e diff --git a/libs/labelbox/src/labelbox/schema/project_overview.py b/libs/labelbox/src/labelbox/schema/project_overview.py index 9f6c31e02..cee195c10 100644 --- a/libs/labelbox/src/labelbox/schema/project_overview.py +++ b/libs/labelbox/src/labelbox/schema/project_overview.py @@ -2,9 +2,10 @@ from typing_extensions import TypedDict from pydantic import BaseModel + class ProjectOverview(BaseModel): """ - Class that represents a project summary as displayed in the UI, in Annotate, + Class that represents a project summary as displayed in the UI, in Annotate, under the "Overview" tab of a particular project. All attributes represent the number of data rows in the corresponding state. @@ -19,7 +20,8 @@ class ProjectOverview(BaseModel): The `labeled` attribute represents the number of data rows that have been labeled. The `total_data_rows` attribute represents the total number of data rows in the project. """ - to_label: int + + to_label: int in_review: int in_rework: int skipped: int @@ -32,16 +34,17 @@ class ProjectOverview(BaseModel): class _QueueDetail(TypedDict): """ Class that represents the detailed information of the queues in the project overview. - The `data` attribute is a list of dictionaries where the keys are the queue names + The `data` attribute is a list of dictionaries where the keys are the queue names and the values are the number of data rows in that queue. """ + data: List[Dict[str, int]] total: int - + class ProjectOverviewDetailed(BaseModel): """ - Class that represents a project summary as displayed in the UI, in Annotate, + Class that represents a project summary as displayed in the UI, in Annotate, under the "Overview" tab of a particular project. This class adds the list of task queues for the `in_review` and `in_rework` attributes. 
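# Illustrative usage sketch (not part of the patch): Project.get_overview()
# returns the counts described above; with details=True the in_review and
# in_rework entries also carry a per-queue breakdown (a plain dict at runtime).
# Identifiers below are placeholders.
from labelbox import Client

client = Client(api_key="<YOUR_API_KEY>")
project = client.get_project("<PROJECT_ID>")

overview = project.get_overview(details=True)
print(overview.to_label, overview.labeled, overview.total_data_rows)
print(overview.in_review["total"], overview.in_review["data"])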
@@ -62,11 +65,11 @@ class ProjectOverviewDetailed(BaseModel): The `total_data_rows` attribute represents the total number of data rows in the project. """ - to_label: int + to_label: int in_review: _QueueDetail in_rework: _QueueDetail skipped: int done: int issues: int labeled: int - total_data_rows: int \ No newline at end of file + total_data_rows: int diff --git a/libs/labelbox/src/labelbox/schema/project_resource_tag.py b/libs/labelbox/src/labelbox/schema/project_resource_tag.py index bfb024c5a..18ca94860 100644 --- a/libs/labelbox/src/labelbox/schema/project_resource_tag.py +++ b/libs/labelbox/src/labelbox/schema/project_resource_tag.py @@ -3,7 +3,7 @@ class ProjectResourceTag(DbObject, Updateable): - """ Project resource tag to associate ProjectResourceTag to Project. + """Project resource tag to associate ProjectResourceTag to Project. Attributes: resourceTagId (str) diff --git a/libs/labelbox/src/labelbox/schema/resource_tag.py b/libs/labelbox/src/labelbox/schema/resource_tag.py index b1f5d6e62..8c0559486 100644 --- a/libs/labelbox/src/labelbox/schema/resource_tag.py +++ b/libs/labelbox/src/labelbox/schema/resource_tag.py @@ -3,7 +3,7 @@ class ResourceTag(DbObject, Updateable): - """ Resource tag to label and identify your labelbox resources easier. + """Resource tag to label and identify your labelbox resources easier. Attributes: text (str) diff --git a/libs/labelbox/src/labelbox/schema/review.py b/libs/labelbox/src/labelbox/schema/review.py index a9ae6d9ae..9a6850a28 100644 --- a/libs/labelbox/src/labelbox/schema/review.py +++ b/libs/labelbox/src/labelbox/schema/review.py @@ -5,7 +5,7 @@ class Review(DbObject, Deletable, Updateable): - """ Reviewing labeled data is a collaborative quality assurance technique. + """Reviewing labeled data is a collaborative quality assurance technique. A Review object indicates the quality of the assigned Label. The aggregated review numbers can be obtained on a Project object. @@ -22,8 +22,8 @@ class Review(DbObject, Deletable, Updateable): """ class NetScore(Enum): - """ Negative, Zero, or Positive. - """ + """Negative, Zero, or Positive.""" + Negative = auto() Zero = auto() Positive = auto() diff --git a/libs/labelbox/src/labelbox/schema/role.py b/libs/labelbox/src/labelbox/schema/role.py index 90930fab9..47cd753e9 100644 --- a/libs/labelbox/src/labelbox/schema/role.py +++ b/libs/labelbox/src/labelbox/schema/role.py @@ -16,26 +16,24 @@ def get_roles(client: "Client") -> Dict[str, "Role"]: query_str = """query GetAvailableUserRolesPyApi { roles { id name } }""" res = client.execute(query_str) _ROLES = {} - for role in res['roles']: - role['name'] = format_role(role['name']) - _ROLES[role['name']] = Role(client, role) + for role in res["roles"]: + role["name"] = format_role(role["name"]) + _ROLES[role["name"]] = Role(client, role) return _ROLES def format_role(name: str): - return name.upper().replace(' ', '_') + return name.upper().replace(" ", "_") class Role(DbObject): name = Field.String("name") -class OrgRole(Role): - ... +class OrgRole(Role): ... -class UserRole(Role): - ... +class UserRole(Role): ... 
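# Illustrative usage sketch (not part of the patch): get_roles() returns the
# available user roles keyed by the normalized name produced by format_role(),
# e.g. "ADMIN" or "LABELER". Identifiers below are placeholders.
from labelbox import Client
from labelbox.schema.role import get_roles

client = Client(api_key="<YOUR_API_KEY>")
roles = get_roles(client)
print(sorted(roles))         # role names, upper-cased with underscores
admin_role = roles["ADMIN"]  # a Role object, usable e.g. with invite_user()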
@dataclass diff --git a/libs/labelbox/src/labelbox/schema/search_filters.py b/libs/labelbox/src/labelbox/schema/search_filters.py index f2ca7beae..13b158678 100644 --- a/libs/labelbox/src/labelbox/schema/search_filters.py +++ b/libs/labelbox/src/labelbox/schema/search_filters.py @@ -24,15 +24,16 @@ class OperationTypeEnum(Enum): Supported search entity types Each type corresponds to a different filter class """ - Organization = 'organization_id' - SharedWithOrganization = 'shared_with_organizations' - Workspace = 'workspace' - Tag = 'tag' - Stage = 'stage' - WorforceRequestedDate = 'workforce_requested_at' - WorkforceStageUpdatedDate = 'workforce_stage_updated_at' - TaskCompletedCount = 'task_completed_count' - TaskRemainingCount = 'task_remaining_count' + + Organization = "organization_id" + SharedWithOrganization = "shared_with_organizations" + Workspace = "workspace" + Tag = "tag" + Stage = "stage" + WorforceRequestedDate = "workforce_requested_at" + WorkforceStageUpdatedDate = "workforce_stage_updated_at" + TaskCompletedCount = "task_completed_count" + TaskRemainingCount = "task_remaining_count" def convert_enum_to_str(enum_or_str: Union[Enum, str]) -> str: @@ -41,50 +42,58 @@ def convert_enum_to_str(enum_or_str: Union[Enum, str]) -> str: return enum_or_str -OperationType = Annotated[OperationTypeEnum, - PlainSerializer(convert_enum_to_str, return_type=str)] +OperationType = Annotated[ + OperationTypeEnum, PlainSerializer(convert_enum_to_str, return_type=str) +] -IsoDatetimeType = Annotated[datetime.datetime, - PlainSerializer(format_iso_datetime)] +IsoDatetimeType = Annotated[ + datetime.datetime, PlainSerializer(format_iso_datetime) +] class IdOperator(Enum): """ Supported operators for ids like org ids, workspace ids, etc """ - Is = 'is' + + Is = "is" class RangeOperatorWithSingleValue(Enum): """ Supported operators for dates """ - Equals = 'EQUALS' - GreaterThanOrEqual = 'GREATER_THAN_OR_EQUAL' - LessThanOrEqual = 'LESS_THAN_OR_EQUAL' + + Equals = "EQUALS" + GreaterThanOrEqual = "GREATER_THAN_OR_EQUAL" + LessThanOrEqual = "LESS_THAN_OR_EQUAL" class RangeDateTimeOperatorWithSingleValue(Enum): """ Supported operators for dates """ - GreaterThanOrEqual = 'GREATER_THAN_OR_EQUAL' - LessThanOrEqual = 'LESS_THAN_OR_EQUAL' + + GreaterThanOrEqual = "GREATER_THAN_OR_EQUAL" + LessThanOrEqual = "LESS_THAN_OR_EQUAL" class RangeOperatorWithValue(Enum): """ Supported operators for date ranges """ - Between = 'BETWEEN' + + Between = "BETWEEN" class OrganizationFilter(BaseSearchFilter): """ Filter for organization to which projects belong """ - operation: OperationType = Field(default=OperationType.Organization, - serialization_alias='type') + + operation: OperationType = Field( + default=OperationType.Organization, serialization_alias="type" + ) operator: IdOperator values: List[str] @@ -95,8 +104,8 @@ class SharedWithOrganizationFilter(BaseSearchFilter): """ operation: OperationType = Field( - default=OperationType.SharedWithOrganization, - serialization_alias='type') + default=OperationType.SharedWithOrganization, serialization_alias="type" + ) operator: IdOperator values: List[str] @@ -105,8 +114,10 @@ class WorkspaceFilter(BaseSearchFilter): """ Filter for workspace """ - operation: OperationType = Field(default=OperationType.Workspace, - serialization_alias='type') + + operation: OperationType = Field( + default=OperationType.Workspace, serialization_alias="type" + ) operator: IdOperator values: List[str] @@ -116,8 +127,10 @@ class TagFilter(BaseSearchFilter): Filter for project tags values 
are tag ids """ - operation: OperationType = Field(default=OperationType.Tag, - serialization_alias='type') + + operation: OperationType = Field( + default=OperationType.Tag, serialization_alias="type" + ) operator: IdOperator values: List[str] @@ -127,18 +140,21 @@ class ProjectStageFilter(BaseSearchFilter): Filter labelbox service / aka project stages Stages are: requested, in_progress, completed etc. as described by LabelingServiceStatus """ - operation: OperationType = Field(default=OperationType.Stage, - serialization_alias='type') + + operation: OperationType = Field( + default=OperationType.Stage, serialization_alias="type" + ) operator: IdOperator values: List[LabelingServiceStatus] - @field_validator('values', mode='before') + @field_validator("values", mode="before") def validate_values(cls, values): disallowed_values = [LabelingServiceStatus.Missing] for value in values: if value in disallowed_values: raise ValueError( - f"{value} is not a valid value for ProjectStageFilter") + f"{value} is not a valid value for ProjectStageFilter" + ) return values @@ -155,6 +171,7 @@ class DateValue(BaseSearchFilter): so for a string '2024-01-01' that is run on a computer in PST, we would convert it to '2024-01-01T08:00:00Z' while the same string in EST will get converted to '2024-01-01T05:00:00Z' """ + operator: RangeDateTimeOperatorWithSingleValue value: IsoDatetimeType @@ -168,9 +185,11 @@ class WorkforceStageUpdatedFilter(BaseSearchFilter): """ Filter for workforce stage updated date """ + operation: OperationType = Field( default=OperationType.WorkforceStageUpdatedDate, - serialization_alias='type') + serialization_alias="type", + ) value: DateValue @@ -178,8 +197,10 @@ class WorkforceRequestedDateFilter(BaseSearchFilter): """ Filter for workforce requested date """ + operation: OperationType = Field( - default=OperationType.WorforceRequestedDate, serialization_alias='type') + default=OperationType.WorforceRequestedDate, serialization_alias="type" + ) value: DateValue @@ -187,14 +208,16 @@ class DateRange(BaseSearchFilter): """ Date range for a search filter """ + min: IsoDatetimeType max: IsoDatetimeType class DateRangeValue(BaseSearchFilter): """ - Date range value for a search filter + Date range value for a search filter """ + operator: RangeOperatorWithValue value: DateRange @@ -203,8 +226,10 @@ class WorkforceRequestedDateRangeFilter(BaseSearchFilter): """ Filter for workforce requested date range """ + operation: OperationType = Field( - default=OperationType.WorforceRequestedDate, serialization_alias='type') + default=OperationType.WorforceRequestedDate, serialization_alias="type" + ) value: DateRangeValue @@ -212,9 +237,11 @@ class WorkforceStageUpdatedRangeFilter(BaseSearchFilter): """ Filter for workforce stage updated date range """ + operation: OperationType = Field( default=OperationType.WorkforceStageUpdatedDate, - serialization_alias='type') + serialization_alias="type", + ) value: DateRangeValue @@ -223,8 +250,10 @@ class TaskCompletedCountFilter(BaseSearchFilter): Filter for completed tasks count A task maps to a data row. Task completed should map to a data row in a labeling queue DONE """ - operation: OperationType = Field(default=OperationType.TaskCompletedCount, - serialization_alias='type') + + operation: OperationType = Field( + default=OperationType.TaskCompletedCount, serialization_alias="type" + ) value: IntegerValue @@ -232,27 +261,41 @@ class TaskRemainingCountFilter(BaseSearchFilter): """ Filter for remaining tasks count. 
Reverse of TaskCompletedCountFilter """ - operation: OperationType = Field(default=OperationType.TaskRemainingCount, - serialization_alias='type') + + operation: OperationType = Field( + default=OperationType.TaskRemainingCount, serialization_alias="type" + ) value: IntegerValue -SearchFilter = Union[OrganizationFilter, WorkspaceFilter, - SharedWithOrganizationFilter, TagFilter, - ProjectStageFilter, WorkforceRequestedDateFilter, - WorkforceStageUpdatedFilter, - WorkforceRequestedDateRangeFilter, - WorkforceStageUpdatedRangeFilter, TaskCompletedCountFilter, - TaskRemainingCountFilter] +SearchFilter = Union[ + OrganizationFilter, + WorkspaceFilter, + SharedWithOrganizationFilter, + TagFilter, + ProjectStageFilter, + WorkforceRequestedDateFilter, + WorkforceStageUpdatedFilter, + WorkforceRequestedDateRangeFilter, + WorkforceStageUpdatedRangeFilter, + TaskCompletedCountFilter, + TaskRemainingCountFilter, +] def _dict_to_graphql_string(d: Union[dict, list, str, int]) -> str: if isinstance(d, dict): - return "{" + ", ".join( - f'{k}: {_dict_to_graphql_string(v)}' for k, v in d.items()) + "}" + return ( + "{" + + ", ".join( + f"{k}: {_dict_to_graphql_string(v)}" for k, v in d.items() + ) + + "}" + ) elif isinstance(d, list): - return "[" + ", ".join( - _dict_to_graphql_string(item) for item in d) + "]" + return ( + "[" + ", ".join(_dict_to_graphql_string(item) for item in d) + "]" + ) else: return f'"{d}"' if isinstance(d, str) else str(d) diff --git a/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py b/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py index f3636e14d..18bd26637 100644 --- a/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py +++ b/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py @@ -2,7 +2,9 @@ from typing import Optional, Dict -from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy +from labelbox.schema.conflict_resolution_strategy import ( + ConflictResolutionStrategy, +) if sys.version_info >= (3, 8): from typing import TypedDict @@ -37,22 +39,24 @@ class SendToAnnotateFromCatalogParams(BaseModel): predictions_ontology_mapping: Optional[Dict[str, str]] = {} annotations_ontology_mapping: Optional[Dict[str, str]] = {} exclude_data_rows_in_project: Optional[bool] = False - override_existing_annotations_rule: Optional[ - ConflictResolutionStrategy] = ConflictResolutionStrategy.KeepExisting + override_existing_annotations_rule: Optional[ConflictResolutionStrategy] = ( + ConflictResolutionStrategy.KeepExisting + ) batch_priority: Optional[int] = 5 @model_validator(mode="after") def check_project_id_or_model_run_id(self): if not self.source_model_run_id and not self.source_project_id: raise ValueError( - 'Either source_project_id or source_model_id are required' + "Either source_project_id or source_model_id are required" ) if self.source_model_run_id and self.source_project_id: raise ValueError( - 'Provide only a source_project_id or source_model_id not both' - ) + "Provide only a source_project_id or source_model_id not both" + ) return self + class SendToAnnotateFromModelParams(TypedDict): """ Extra parameters for sending data rows to a project through a model run. 
@@ -73,36 +77,35 @@ class SendToAnnotateFromModelParams(TypedDict): batch_priority: Optional[int] -def build_annotations_input(project_ontology_mapping: Optional[Dict[str, str]], - source_project_id: str): +def build_annotations_input( + project_ontology_mapping: Optional[Dict[str, str]], source_project_id: str +): return { - "projectId": - source_project_id, - "featureSchemaIdsMapping": - project_ontology_mapping if project_ontology_mapping else {}, + "projectId": source_project_id, + "featureSchemaIdsMapping": project_ontology_mapping + if project_ontology_mapping + else {}, } def build_destination_task_queue_input(task_queue_id: str): - destination_task_queue = { - "type": "id", - "value": task_queue_id - } if task_queue_id else { - "type": "done" - } + destination_task_queue = ( + {"type": "id", "value": task_queue_id} + if task_queue_id + else {"type": "done"} + ) return destination_task_queue -def build_predictions_input(model_run_ontology_mapping: Optional[Dict[str, - str]], - source_model_run_id: str): +def build_predictions_input( + model_run_ontology_mapping: Optional[Dict[str, str]], + source_model_run_id: str, +): return { - "featureSchemaIdsMapping": - model_run_ontology_mapping if model_run_ontology_mapping else {}, - "modelRunId": - source_model_run_id, - "minConfidence": - 0, - "maxConfidence": - 1 + "featureSchemaIdsMapping": model_run_ontology_mapping + if model_run_ontology_mapping + else {}, + "modelRunId": source_model_run_id, + "minConfidence": 0, + "maxConfidence": 1, } diff --git a/libs/labelbox/src/labelbox/schema/serialization.py b/libs/labelbox/src/labelbox/schema/serialization.py index cfbbb04f8..ca5537fd9 100644 --- a/libs/labelbox/src/labelbox/schema/serialization.py +++ b/libs/labelbox/src/labelbox/schema/serialization.py @@ -5,8 +5,8 @@ def serialize_labels( - objects: Union[List[Dict[str, Any]], - List["Label"]]) -> List[Dict[str, Any]]: + objects: Union[List[Dict[str, Any]], List["Label"]], +) -> List[Dict[str, Any]]: """ Checks if objects are of type Labels and serializes labels for annotation import. Serialization depends the labelbox[data] package, therefore NDJsonConverter is only loaded if using `Label` objects instead of `dict` objects. """ @@ -17,6 +17,7 @@ def serialize_labels( if is_label_type: # If a Label object exists, labelbox[data] is already installed, so no error checking is needed. 
from labelbox.data.serialization import NDJsonConverter + labels = cast(List["Label"], objects) return list(NDJsonConverter.serialize(labels)) diff --git a/libs/labelbox/src/labelbox/schema/slice.py b/libs/labelbox/src/labelbox/schema/slice.py index ffd1f2768..624731024 100644 --- a/libs/labelbox/src/labelbox/schema/slice.py +++ b/libs/labelbox/src/labelbox/schema/slice.py @@ -4,7 +4,10 @@ from labelbox.orm.db_object import DbObject, experimental from labelbox.orm.model import Field from labelbox.pagination import PaginatedCollection -from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params +from labelbox.schema.export_params import ( + CatalogExportParams, + validate_catalog_export_params, +) from labelbox.schema.export_task import ExportTask from labelbox.schema.identifiable import GlobalKey, UniqueId from labelbox.schema.task import Task @@ -41,7 +44,7 @@ def __init__(self, id: str, global_key: Optional[str]): def to_hash(self): return { "id": self.id.key, - "global_key": self.global_key.key if self.global_key else None + "global_key": self.global_key.key if self.global_key else None, } @@ -81,10 +84,11 @@ def get_data_row_ids(self) -> PaginatedCollection: return PaginatedCollection( client=self.client, query=query_str, - params={'id': str(self.uid)}, - dereferencing=['getDataRowIdsBySavedQuery', 'nodes'], + params={"id": str(self.uid)}, + dereferencing=["getDataRowIdsBySavedQuery", "nodes"], obj_class=lambda _, data_row_id: data_row_id, - cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor']) + cursor_path=["getDataRowIdsBySavedQuery", "pageInfo", "endCursor"], + ) def get_data_row_identifiers(self) -> PaginatedCollection: """ @@ -116,18 +120,24 @@ def get_data_row_identifiers(self) -> PaginatedCollection: return PaginatedCollection( client=self.client, query=query_str, - params={'id': str(self.uid)}, - dereferencing=['getDataRowIdentifiersBySavedQuery', 'nodes'], + params={"id": str(self.uid)}, + dereferencing=["getDataRowIdentifiersBySavedQuery", "nodes"], obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey( - data_row_id_and_gk.get('id'), - data_row_id_and_gk.get('globalKey', None)), + data_row_id_and_gk.get("id"), + data_row_id_and_gk.get("globalKey", None), + ), cursor_path=[ - 'getDataRowIdentifiersBySavedQuery', 'pageInfo', 'endCursor' - ]) + "getDataRowIdentifiersBySavedQuery", + "pageInfo", + "endCursor", + ], + ) - def export(self, - task_name: Optional[str] = None, - params: Optional[CatalogExportParams] = None) -> ExportTask: + def export( + self, + task_name: Optional[str] = None, + params: Optional[CatalogExportParams] = None, + ) -> ExportTask: """ Creates a slice export task with the given params and returns the task. 
>>> slice = client.get_catalog_slice("SLICE_ID") @@ -155,7 +165,7 @@ def export_v2( >>> task.result """ task, is_streamable = self._export(task_name, params) - if (is_streamable): + if is_streamable: return ExportTask(task, True) return task @@ -165,73 +175,70 @@ def _export( params: Optional[CatalogExportParams] = None, streamable: bool = False, ) -> Tuple[Task, bool]: - _params = params or CatalogExportParams({ - "attachments": False, - "embeddings": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "model_run_ids": None, - "project_ids": None, - "interpolated_frames": False, - "all_projects": False, - "all_model_runs": False, - }) + _params = params or CatalogExportParams( + { + "attachments": False, + "embeddings": False, + "metadata_fields": False, + "data_row_details": False, + "project_details": False, + "performance_details": False, + "label_details": False, + "media_type_override": None, + "model_run_ids": None, + "project_ids": None, + "interpolated_frames": False, + "all_projects": False, + "all_model_runs": False, + } + ) validate_catalog_export_params(_params) mutation_name = "exportDataRowsInSlice" create_task_query_str = ( f"mutation {mutation_name}PyApi" f"($input: ExportDataRowsInSliceInput!)" - f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}") + f"{{{mutation_name}(input: $input){{taskId isStreamable}}}}" + ) - media_type_override = _params.get('media_type_override', None) + media_type_override = _params.get("media_type_override", None) query_params = { "input": { "taskName": task_name, - "filters": { - "sliceId": self.uid - }, + "filters": {"sliceId": self.uid}, "isStreamableReady": True, "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeEmbeddings": - _params.get('embeddings', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), - "projectIds": - _params.get('project_ids', None), - "modelRunIds": - _params.get('model_run_ids', None), - "allProjects": - _params.get('all_projects', False), - "allModelRuns": - _params.get('all_model_runs', False), + "mediaTypeOverride": media_type_override.value + if media_type_override is not None + else None, + "includeAttachments": _params.get("attachments", False), + "includeEmbeddings": _params.get("embeddings", False), + "includeMetadata": _params.get("metadata_fields", False), + "includeDataRowDetails": _params.get( + "data_row_details", False + ), + "includeProjectDetails": _params.get( + "project_details", False + ), + "includePerformanceDetails": _params.get( + "performance_details", False + ), + "includeLabelDetails": _params.get("label_details", False), + "includeInterpolatedFrames": _params.get( + "interpolated_frames", False + ), + "projectIds": _params.get("project_ids", None), + "modelRunIds": _params.get("model_run_ids", None), + "allProjects": _params.get("all_projects", False), + "allModelRuns": _params.get("all_model_runs", False), }, "streamable": streamable, } } - res = 
self.client.execute(create_task_query_str, - query_params, - error_log_key="errors") + res = self.client.execute( + create_task_query_str, query_params, error_log_key="errors" + ) res = res[mutation_name] task_id = res["taskId"] is_streamable = res["isStreamable"] @@ -284,20 +291,21 @@ def get_data_row_ids(self, model_run_id: str) -> PaginatedCollection: return PaginatedCollection( client=self.client, query=ModelSlice.query_str(), - params={ - 'id': str(self.uid), - 'modelRunId': model_run_id - }, - dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'], - obj_class=lambda _, data_row_id_and_gk: data_row_id_and_gk.get('id' - ), + params={"id": str(self.uid), "modelRunId": model_run_id}, + dereferencing=["getDataRowIdentifiersBySavedModelQuery", "nodes"], + obj_class=lambda _, data_row_id_and_gk: data_row_id_and_gk.get( + "id" + ), cursor_path=[ - 'getDataRowIdentifiersBySavedModelQuery', 'pageInfo', - 'endCursor' - ]) + "getDataRowIdentifiersBySavedModelQuery", + "pageInfo", + "endCursor", + ], + ) - def get_data_row_identifiers(self, - model_run_id: str) -> PaginatedCollection: + def get_data_row_identifiers( + self, model_run_id: str + ) -> PaginatedCollection: """ Fetches all data row ids and global keys (where defined) that match this Slice @@ -310,15 +318,15 @@ def get_data_row_identifiers(self, return PaginatedCollection( client=self.client, query=ModelSlice.query_str(), - params={ - 'id': str(self.uid), - 'modelRunId': model_run_id - }, - dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'], + params={"id": str(self.uid), "modelRunId": model_run_id}, + dereferencing=["getDataRowIdentifiersBySavedModelQuery", "nodes"], obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey( - data_row_id_and_gk.get('id'), - data_row_id_and_gk.get('globalKey', None)), + data_row_id_and_gk.get("id"), + data_row_id_and_gk.get("globalKey", None), + ), cursor_path=[ - 'getDataRowIdentifiersBySavedModelQuery', 'pageInfo', - 'endCursor' - ]) + "getDataRowIdentifiersBySavedModelQuery", + "pageInfo", + "endCursor", + ], + ) diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py index 19d27c325..9d7a26e1d 100644 --- a/libs/labelbox/src/labelbox/schema/task.py +++ b/libs/labelbox/src/labelbox/schema/task.py @@ -11,7 +11,8 @@ from labelbox.pagination import PaginatedCollection from labelbox.schema.internal.datarow_upload_constants import ( - DOWNLOAD_RESULT_PAGE_SIZE,) + DOWNLOAD_RESULT_PAGE_SIZE, +) if TYPE_CHECKING: from labelbox import User @@ -25,7 +26,7 @@ def lru_cache() -> Callable[..., Callable[..., Dict[str, Any]]]: class Task(DbObject): - """ Represents a server-side process that might take a longer time to process. + """Represents a server-side process that might take a longer time to process. Allows the Task state to be updated and checked on the client side. 
Attributes: @@ -38,6 +39,7 @@ class Task(DbObject): created_by (Relationship): `ToOne` relationship to User organization (Relationship): `ToOne` relationship to Organization """ + updated_at = Field.DateTime("updated_at") created_at = Field.DateTime("created_at") name = Field.String("name") @@ -54,18 +56,21 @@ class Task(DbObject): organization = Relationship.ToOne("Organization") def __eq__(self, task): - return isinstance( - task, Task) and task.uid == self.uid and task.type == self.type + return ( + isinstance(task, Task) + and task.uid == self.uid + and task.type == self.type + ) def __hash__(self): return hash(self.uid) # Import and upsert have several instances of special casing def is_creation_task(self) -> bool: - return self.name == 'JSON Import' or self.type == 'adv-upsert-data-rows' + return self.name == "JSON Import" or self.type == "adv-upsert-data-rows" def refresh(self) -> None: - """ Refreshes Task data from the server. """ + """Refreshes Task data from the server.""" assert self._user is not None tasks = list(self._user.created_tasks(where=Task.uid == self.uid)) if len(tasks) != 1: @@ -84,24 +89,25 @@ def has_errors(self) -> bool: return bool(self.failed_data_rows) return self.status == "FAILED" - def wait_until_done(self, - timeout_seconds: float = 300.0, - check_frequency: float = 2.0) -> None: + def wait_until_done( + self, timeout_seconds: float = 300.0, check_frequency: float = 2.0 + ) -> None: self.wait_till_done(timeout_seconds, check_frequency) - def wait_till_done(self, - timeout_seconds: float = 300.0, - check_frequency: float = 2.0) -> None: - """ Waits until the task is completed. Periodically queries the server - to update the task attributes. + def wait_till_done( + self, timeout_seconds: float = 300.0, check_frequency: float = 2.0 + ) -> None: + """Waits until the task is completed. Periodically queries the server + to update the task attributes. - Args: - timeout_seconds (float): Maximum time this method can block, in seconds. Defaults to five minutes. - check_frequency (float): Frequency of queries to server to update the task attributes, in seconds. Defaults to two seconds. Minimal value is two seconds. - """ + Args: + timeout_seconds (float): Maximum time this method can block, in seconds. Defaults to five minutes. + check_frequency (float): Frequency of queries to server to update the task attributes, in seconds. Defaults to two seconds. Minimal value is two seconds. + """ if check_frequency < 2.0: raise ValueError( - "Expected check frequency to be two seconds or more") + "Expected check frequency to be two seconds or more" + ) while timeout_seconds > 0: if self.status != "IN_PROGRESS": if self.has_errors(): @@ -109,16 +115,16 @@ def wait_till_done(self, "There are errors present. Please look at `task.errors` for more details" ) return - logger.debug("Task.wait_till_done sleeping for %d seconds", - check_frequency) + logger.debug( + "Task.wait_till_done sleeping for %d seconds", check_frequency + ) time.sleep(check_frequency) timeout_seconds -= check_frequency self.refresh() @property def errors(self) -> Optional[Dict[str, Any]]: - """ Fetch the error associated with an import task. 
- """ + """Fetch the error associated with an import task.""" if self.is_creation_task(): if self.status == "FAILED": result = self._fetch_remote_json() @@ -126,10 +132,12 @@ def errors(self) -> Optional[Dict[str, Any]]: elif self.status == "COMPLETE": return self.failed_data_rows elif self.type == "export-data-rows": - return self._fetch_remote_json(remote_json_field='errors_url') - elif (self.type == "add-data-rows-to-batch" or - self.type == "send-to-task-queue" or - self.type == "send-to-annotate"): + return self._fetch_remote_json(remote_json_field="errors_url") + elif ( + self.type == "add-data-rows-to-batch" + or self.type == "send-to-task-queue" + or self.type == "send-to-annotate" + ): if self.status == "FAILED": # for these tasks, the error is embedded in the result itself return json.loads(self.result_url) @@ -137,26 +145,27 @@ def errors(self) -> Optional[Dict[str, Any]]: @property def result(self) -> Union[List[Dict[str, Any]], Dict[str, Any]]: - """ Fetch the result for an import task. - """ + """Fetch the result for an import task.""" if self.status == "FAILED": raise ValueError(f"Job failed. Errors : {self.errors}") else: result = self._fetch_remote_json() - if self.type == 'export-data-rows': + if self.type == "export-data-rows": return result - return [{ - 'id': data_row['id'], - 'external_id': data_row.get('externalId'), - 'row_data': data_row['rowData'], - 'global_key': data_row.get('globalKey'), - } for data_row in result['createdDataRows']] + return [ + { + "id": data_row["id"], + "external_id": data_row.get("externalId"), + "row_data": data_row["rowData"], + "global_key": data_row.get("globalKey"), + } + for data_row in result["createdDataRows"] + ] @property def failed_data_rows(self) -> Optional[Dict[str, Any]]: - """ Fetch data rows which failed to be created for an import task. - """ + """Fetch data rows which failed to be created for an import task.""" result = self._fetch_remote_json() if len(result.get("errors", [])) > 0: return result["errors"] @@ -165,8 +174,7 @@ def failed_data_rows(self) -> Optional[Dict[str, Any]]: @property def created_data_rows(self) -> Optional[Dict[str, Any]]: - """ Fetch data rows which successfully created for an import task. - """ + """Fetch data rows which successfully created for an import task.""" result = self._fetch_remote_json() if len(result.get("createdDataRows", [])) > 0: return result["createdDataRows"] @@ -174,23 +182,22 @@ def created_data_rows(self) -> Optional[Dict[str, Any]]: return None @lru_cache() - def _fetch_remote_json(self, - remote_json_field: Optional[str] = None - ) -> Dict[str, Any]: - """ Function for fetching and caching the result data. 
- """ + def _fetch_remote_json( + self, remote_json_field: Optional[str] = None + ) -> Dict[str, Any]: + """Function for fetching and caching the result data.""" def download_result(remote_json_field: Optional[str], format: str): - url = getattr(self, remote_json_field or 'result_url') + url = getattr(self, remote_json_field or "result_url") if url is None: return None response = requests.get(url) response.raise_for_status() - if format == 'json': + if format == "json": return response.json() - elif format == 'ndjson': + elif format == "ndjson": return parser.loads(response.text) else: raise ValueError( @@ -198,9 +205,9 @@ def download_result(remote_json_field: Optional[str], format: str): ) if self.is_creation_task(): - format = 'json' - elif self.type == 'export-data-rows': - format = 'ndjson' + format = "json" + elif self.type == "export-data-rows": + format = "ndjson" else: raise ValueError( "Task result is only supported for `JSON Import` and `export` tasks." @@ -221,7 +228,8 @@ def download_result(remote_json_field: Optional[str], format: str): def get_task(client, task_id): user: User = client.get_user() tasks: List[Task] = list( - user.created_tasks(where=Entity.Task.uid == task_id)) + user.created_tasks(where=Entity.Task.uid == task_id) + ) # Cache user in a private variable as the relationship can't be # resolved due to server-side limitations (see Task.created_by) # for more info. @@ -261,12 +269,14 @@ def errors(self) -> Optional[List[Dict[str, Any]]]: # type: ignore @property def created_data_rows( # type: ignore - self) -> Optional[List[Dict[str, Any]]]: + self, + ) -> Optional[List[Dict[str, Any]]]: return self.result @property def failed_data_rows( # type: ignore - self) -> Optional[List[Dict[str, Any]]]: + self, + ) -> Optional[List[Dict[str, Any]]]: return self.errors def _download_results_paginated(self) -> PaginatedCollection: @@ -289,23 +299,23 @@ def _download_results_paginated(self) -> PaginatedCollection: """ params = { - 'taskId': self.uid, - 'first': page_size, - 'from': from_cursor, + "taskId": self.uid, + "first": page_size, + "from": from_cursor, } return PaginatedCollection( client=self.client, query=query_str, params=params, - dereferencing=['successesfulDataRowImports', 'nodes'], + dereferencing=["successesfulDataRowImports", "nodes"], obj_class=lambda _, data_row: { - 'id': data_row.get('id'), - 'external_id': data_row.get('externalId'), - 'row_data': data_row.get('rowData'), - 'global_key': data_row.get('globalKey'), + "id": data_row.get("id"), + "external_id": data_row.get("externalId"), + "row_data": data_row.get("rowData"), + "global_key": data_row.get("globalKey"), }, - cursor_path=['successesfulDataRowImports', 'after'], + cursor_path=["successesfulDataRowImports", "after"], ) def _download_errors_paginated(self) -> PaginatedCollection: @@ -340,32 +350,33 @@ def _download_errors_paginated(self) -> PaginatedCollection: """ params = { - 'taskId': self.uid, - 'first': page_size, - 'from': from_cursor, + "taskId": self.uid, + "first": page_size, + "from": from_cursor, } def convert_errors_to_legacy_format(client, data_row): - spec = data_row.get('spec', {}) + spec = data_row.get("spec", {}) return { - 'message': - data_row.get('message'), - 'failedDataRows': [{ - 'externalId': spec.get('externalId'), - 'rowData': spec.get('rowData'), - 'globalKey': spec.get('globalKey'), - 'metadata': spec.get('metadata', []), - 'attachments': spec.get('attachments', []), - }] + "message": data_row.get("message"), + "failedDataRows": [ + { + "externalId": 
spec.get("externalId"), + "rowData": spec.get("rowData"), + "globalKey": spec.get("globalKey"), + "metadata": spec.get("metadata", []), + "attachments": spec.get("attachments", []), + } + ], } return PaginatedCollection( client=self.client, query=query_str, params=params, - dereferencing=['failedDataRowImports', 'results'], + dereferencing=["failedDataRowImports", "results"], obj_class=convert_errors_to_legacy_format, - cursor_path=['failedDataRowImports', 'after'], + cursor_path=["failedDataRowImports", "after"], ) def _results_as_list(self) -> Optional[List[Dict[str, Any]]]: diff --git a/libs/labelbox/src/labelbox/schema/user.py b/libs/labelbox/src/labelbox/schema/user.py index 430868b85..f7b3cd0d6 100644 --- a/libs/labelbox/src/labelbox/schema/user.py +++ b/libs/labelbox/src/labelbox/schema/user.py @@ -7,7 +7,7 @@ class User(DbObject): - """ A User is a registered Labelbox user (for example you) associated with + """A User is a registered Labelbox user (for example you) associated with data they create or import and an Organization they belong to. Attributes: @@ -43,7 +43,7 @@ class User(DbObject): org_role = Relationship.ToOne("OrgRole", False) def update_org_role(self, role: "Role") -> None: - """ Updated the `User`s organization role. + """Updated the `User`s organization role. See client.get_roles() to get all valid roles If you a user is converted from project level permissions to org level permissions and then convert back, their permissions will remain for each individual project @@ -58,23 +58,22 @@ def update_org_role(self, role: "Role") -> None: setOrganizationRole(data: {userId: $userId, roleId: $roleId}) { id name }} """ % (user_id_param, role_id_param) - self.client.execute(query_str, { - user_id_param: self.uid, - role_id_param: role.uid - }) + self.client.execute( + query_str, {user_id_param: self.uid, role_id_param: role.uid} + ) def remove_from_project(self, project: "Project") -> None: - """ Removes a User from a project. Only used for project based users. + """Removes a User from a project. Only used for project based users. Project based user means their org role is "NONE" Args: project (Project): Project to remove user from """ - self.upsert_project_role(project, self.client.get_roles()['NONE']) + self.upsert_project_role(project, self.client.get_roles()["NONE"]) def upsert_project_role(self, project: "Project", role: "Role") -> None: - """ Updates or replaces a User's role in a project. + """Updates or replaces a User's role in a project. Args: project (Project): The project to update the users permissions for @@ -82,21 +81,30 @@ def upsert_project_role(self, project: "Project", role: "Role") -> None: """ org_role = self.org_role() - if org_role.name.upper() != 'NONE': + if org_role.name.upper() != "NONE": raise ValueError( - "User is not project based and has access to all projects") + "User is not project based and has access to all projects" + ) project_id_param = "projectId" user_id_param = "userId" role_id_param = "roleId" query_str = """mutation SetProjectMembershipPyApi($%s: ID!, $%s: ID!, $%s: ID!) 
{ setProjectMembership(data: {%s: $userId, roleId: $%s, projectId: $%s}) {id}} - """ % (user_id_param, role_id_param, project_id_param, user_id_param, - role_id_param, project_id_param) + """ % ( + user_id_param, + role_id_param, + project_id_param, + user_id_param, + role_id_param, + project_id_param, + ) self.client.execute( - query_str, { + query_str, + { project_id_param: project.uid, user_id_param: self.uid, - role_id_param: role.uid - }) + role_id_param: role.uid, + }, + ) diff --git a/libs/labelbox/src/labelbox/schema/user_group.py b/libs/labelbox/src/labelbox/schema/user_group.py index 91cdb159c..9d506bf92 100644 --- a/libs/labelbox/src/labelbox/schema/user_group.py +++ b/libs/labelbox/src/labelbox/schema/user_group.py @@ -6,7 +6,11 @@ from labelbox.exceptions import ResourceCreationError from labelbox.schema.user import User from labelbox.schema.project import Project -from labelbox.exceptions import UnprocessableEntityError, MalformedQueryException, ResourceNotFoundError +from labelbox.exceptions import ( + UnprocessableEntityError, + MalformedQueryException, + ResourceNotFoundError, +) from labelbox.schema.queue_mode import QueueMode from labelbox.schema.ontology_kind import EditorTaskType from labelbox.schema.media_type import MediaType @@ -28,6 +32,7 @@ class UserGroupColor(Enum): YELLOW (str): Hex color code for yellow (#E7BF00). GRAY (str): Hex color code for gray (#B8C4D3). """ + BLUE = "9EC5FF" PURPLE = "CEB8FF" ORANGE = "FFB35F" @@ -38,7 +43,7 @@ class UserGroupColor(Enum): YELLOW = "E7BF00" GRAY = "B8C4D3" - + class UserGroup(BaseModel): """ Represents a user group in Labelbox. @@ -59,14 +64,14 @@ class UserGroup(BaseModel): delete(self) -> bool get_user_groups(client: Client) -> Iterator["UserGroup"] """ + id: str name: str color: UserGroupColor users: Set[User] projects: Set[Project] client: Client - model_config = ConfigDict(arbitrary_types_allowed = True) - + model_config = ConfigDict(arbitrary_types_allowed=True) def __init__( self, @@ -75,7 +80,7 @@ def __init__( name: str = "", color: UserGroupColor = UserGroupColor.BLUE, users: Set[User] = set(), - projects: Set[Project] = set() + projects: Set[Project] = set(), ): """ Initializes a UserGroup object. @@ -91,9 +96,18 @@ def __init__( Raises: RuntimeError: If the experimental feature is not enabled in the client. 
""" - super().__init__(client=client, id=id, name=name, color=color, users=users, projects=projects) + super().__init__( + client=client, + id=id, + name=name, + color=color, + users=users, + projects=projects, + ) if not self.client.enable_experimental: - raise RuntimeError("Please enable experimental in client to use UserGroups") + raise RuntimeError( + "Please enable experimental in client to use UserGroups" + ) def get(self) -> "UserGroup": """ @@ -140,11 +154,17 @@ def get(self) -> "UserGroup": } result = self.client.execute(query, params) if not result: - raise ResourceNotFoundError(message="Failed to get user group as user group does not exist") + raise ResourceNotFoundError( + message="Failed to get user group as user group does not exist" + ) self.name = result["userGroup"]["name"] self.color = UserGroupColor(result["userGroup"]["color"]) - self.projects = self._get_projects_set(result["userGroup"]["projects"]["nodes"]) - self.users = self._get_users_set(result["userGroup"]["members"]["nodes"]) + self.projects = self._get_projects_set( + result["userGroup"]["projects"]["nodes"] + ) + self.users = self._get_users_set( + result["userGroup"]["members"]["nodes"] + ) return self def update(self) -> "UserGroup": @@ -190,23 +210,18 @@ def update(self) -> "UserGroup": } """ params = { - "id": - self.id, - "name": - self.name, - "color": - self.color.value, - "projectIds": [ - project.uid for project in self.projects - ], - "userIds": [ - user.uid for user in self.users - ] + "id": self.id, + "name": self.name, + "color": self.color.value, + "projectIds": [project.uid for project in self.projects], + "userIds": [user.uid for user in self.users], } try: result = self.client.execute(query, params) if not result: - raise ResourceNotFoundError(message="Failed to update user group as user group does not exist") + raise ResourceNotFoundError( + message="Failed to update user group as user group does not exist" + ) except MalformedQueryException as e: raise UnprocessableEntityError("Failed to update user group") from e return self @@ -257,26 +272,22 @@ def create(self) -> "UserGroup": } """ params = { - "name": - self.name, - "color": - self.color.value, - "projectIds": [ - project.uid for project in self.projects - ], - "userIds": [ - user.uid for user in self.users - ] + "name": self.name, + "color": self.color.value, + "projectIds": [project.uid for project in self.projects], + "userIds": [user.uid for user in self.users], } result = None error = None - try: + try: result = self.client.execute(query, params) except Exception as e: error = e if not result or error: # this is client side only, server doesn't have an equivalent error - raise ResourceCreationError(f"Failed to create user group, either user group name is in use currently, or provided user or projects don't exist server error: {error}") + raise ResourceCreationError( + f"Failed to create user group, either user group name is in use currently, or provided user or projects don't exist server error: {error}" + ) result = result["createUserGroup"]["group"] self.id = result["id"] return self @@ -291,7 +302,7 @@ def delete(self) -> bool: Returns: bool: True if the user group was successfully deleted, False otherwise. - + Raises: ResourceNotFoundError: If the deletion of the user group fails due to not existing ValueError: If the group ID is not provided. 
@@ -308,7 +319,9 @@ def delete(self) -> bool: params = {"id": self.id} result = self.client.execute(query, params) if not result: - raise ResourceNotFoundError(message="Failed to delete user group as user group does not exist") + raise ResourceNotFoundError( + message="Failed to delete user group as user group does not exist" + ) return result["deleteUserGroup"]["success"] def get_user_groups(self) -> Iterator["UserGroup"]: @@ -349,8 +362,9 @@ def get_user_groups(self) -> Iterator["UserGroup"]: """ nextCursor = None while True: - userGroups = self.client.execute( - query, {"after": nextCursor})["userGroups"] + userGroups = self.client.execute(query, {"after": nextCursor})[ + "userGroups" + ] if not userGroups: return yield @@ -361,7 +375,9 @@ def get_user_groups(self) -> Iterator["UserGroup"]: userGroup.name = group["name"] userGroup.color = UserGroupColor(group["color"]) userGroup.users = self._get_users_set(group["members"]["nodes"]) - userGroup.projects = self._get_projects_set(group["projects"]["nodes"]) + userGroup.projects = self._get_projects_set( + group["projects"]["nodes"] + ) yield userGroup nextCursor = userGroups["nextCursor"] if not nextCursor: diff --git a/libs/labelbox/src/labelbox/schema/webhook.py b/libs/labelbox/src/labelbox/schema/webhook.py index 1f1653c52..0eebe157e 100644 --- a/libs/labelbox/src/labelbox/schema/webhook.py +++ b/libs/labelbox/src/labelbox/schema/webhook.py @@ -10,7 +10,7 @@ class Webhook(DbObject, Updateable): - """ Represents a server-side rule for sending notifications to a web-server + """Represents a server-side rule for sending notifications to a web-server whenever one of several predefined actions happens within a context of a Project or an Organization. @@ -53,7 +53,7 @@ class Topic(Enum): @staticmethod def create(client, topics, url, secret, project) -> "Webhook": - """ Creates a Webhook. + """Creates a Webhook. Args: client (Client): The Labelbox client used to connect @@ -84,13 +84,19 @@ def create(client, topics, url, secret, project) -> "Webhook": raise ValueError("URL must be a non-empty string.") Webhook.validate_topics(topics) - project_str = "" if project is None \ - else ("project:{id:\"%s\"}," % project.uid) + project_str = ( + "" if project is None else ('project:{id:"%s"},' % project.uid) + ) query_str = """mutation CreateWebhookPyApi { createWebhook(data:{%s topics:{set:[%s]}, url:"%s", secret:"%s" }){%s} - } """ % (project_str, " ".join(topics), url, secret, - query.results_query_part(Entity.Webhook)) + } """ % ( + project_str, + " ".join(topics), + url, + secret, + query.results_query_part(Entity.Webhook), + ) return Webhook(client, client.execute(query_str)["createWebhook"]) @@ -98,7 +104,8 @@ def create(client, topics, url, secret, project) -> "Webhook": def validate_topics(topics) -> None: if isinstance(topics, str) or not isinstance(topics, Iterable): raise TypeError( - f"Topics must be List[Webhook.Topic]. Found `{topics}`") + f"Topics must be List[Webhook.Topic]. Found `{topics}`" + ) for topic in topics: Webhook.validate_value(topic, Webhook.Topic) @@ -118,7 +125,7 @@ def delete(self) -> None: self.update(status=self.Status.INACTIVE.value) def update(self, topics=None, url=None, status=None): - """ Updates the Webhook. + """Updates the Webhook. Args: topics (Optional[List[Topic]]): The new topics. 
@@ -137,15 +144,17 @@ def update(self, topics=None, url=None, status=None): if status is not None: self.validate_value(status, self.Status) - topics_str = "" if topics is None \ - else "topics: {set: [%s]}" % " ".join(topics) - url_str = "" if url is None else "url: \"%s\"" % url + topics_str = ( + "" if topics is None else "topics: {set: [%s]}" % " ".join(topics) + ) + url_str = "" if url is None else 'url: "%s"' % url status_str = "" if status is None else "status: %s" % status query_str = """mutation UpdateWebhookPyApi { updateWebhook(where: {id: "%s"} data:{%s}){%s}} """ % ( - self.uid, ", ".join(filter(None, - (topics_str, url_str, status_str))), - query.results_query_part(Entity.Webhook)) + self.uid, + ", ".join(filter(None, (topics_str, url_str, status_str))), + query.results_query_part(Entity.Webhook), + ) self._set_field_values(self.client.execute(query_str)["updateWebhook"]) diff --git a/libs/labelbox/src/labelbox/types.py b/libs/labelbox/src/labelbox/types.py index 98f7042ae..0c0c2904f 100644 --- a/libs/labelbox/src/labelbox/types.py +++ b/libs/labelbox/src/labelbox/types.py @@ -3,4 +3,4 @@ except ImportError: raise ImportError( "There are missing dependencies for `labelbox.types`, use `pip install labelbox[data] --upgrade` to install missing dependencies." - ) \ No newline at end of file + ) diff --git a/libs/labelbox/src/labelbox/typing_imports.py b/libs/labelbox/src/labelbox/typing_imports.py index 2c2716710..6edfb9bef 100644 --- a/libs/labelbox/src/labelbox/typing_imports.py +++ b/libs/labelbox/src/labelbox/typing_imports.py @@ -1,10 +1,11 @@ """ -This module imports types that differ across python versions, so other modules +This module imports types that differ across python versions, so other modules don't have to worry about where they should be imported from. """ import sys + if sys.version_info >= (3, 8): from typing import Literal else: - from typing_extensions import Literal \ No newline at end of file + from typing_extensions import Literal diff --git a/libs/labelbox/src/labelbox/utils.py b/libs/labelbox/src/labelbox/utils.py index 21f0c338b..c76ce188f 100644 --- a/libs/labelbox/src/labelbox/utils.py +++ b/libs/labelbox/src/labelbox/utils.py @@ -6,11 +6,17 @@ from dateutil.utils import default_tzinfo from urllib.parse import urlparse -from pydantic import BaseModel, ConfigDict, model_serializer, AliasGenerator, AliasChoices +from pydantic import ( + BaseModel, + ConfigDict, + model_serializer, + AliasGenerator, + AliasChoices, +) from pydantic.alias_generators import to_camel, to_pascal -UPPERCASE_COMPONENTS = ['uri', 'rgb'] -ISO_DATETIME_FORMAT = '%Y-%m-%dT%H:%M:%SZ' +UPPERCASE_COMPONENTS = ["uri", "rgb"] +ISO_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ" DFLT_TZ = tzoffset("UTC", 0000) @@ -26,22 +32,22 @@ def _convert(s, sep, title): def camel_case(s): - """ Converts a string in [snake|camel|title]case to camelCase. """ + """Converts a string in [snake|camel|title]case to camelCase.""" return _convert(s, "", lambda i: i > 0) def title_case(s): - """ Converts a string in [snake|camel|title]case to TitleCase. """ + """Converts a string in [snake|camel|title]case to TitleCase.""" return _convert(s, "", lambda i: True) def snake_case(s): - """ Converts a string in [snake|camel|title]case to snake_case. """ + """Converts a string in [snake|camel|title]case to snake_case.""" return _convert(s, "_", lambda i: False) def sentence_case(s: str) -> str: - """ Converts a string in [snake|camel|title]case to Sentence case. 
""" + """Converts a string in [snake|camel|title]case to Sentence case.""" # Replace underscores with spaces and convert to lower case sentence_str = s.replace("_", " ").lower() # Capitalize the first letter of each word @@ -62,7 +68,11 @@ def is_valid_uri(uri): class _CamelCaseMixin(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed = True, alias_generator = to_camel, populate_by_name = True) + model_config = ConfigDict( + arbitrary_types_allowed=True, + alias_generator=to_camel, + populate_by_name=True, + ) class _NoCoercionMixin: @@ -72,7 +82,7 @@ class _NoCoercionMixin: uninteded behavior. This mixin uses a class_name discriminator field to prevent pydantic from - corecing the type of the object. Add a class_name field to the class you + corecing the type of the object. Add a class_name field to the class you want to discrimniate and use this mixin class to remove the discriminator when serializing the object. @@ -81,10 +91,11 @@ class ConversationData(BaseData, _NoCoercionMixin): class_name: Literal["ConversationData"] = "ConversationData" """ + @model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) - res.pop('class_name') + res.pop("class_name") return res diff --git a/libs/labelbox/tests/conftest.py b/libs/labelbox/tests/conftest.py index 4251ac698..446db396b 100644 --- a/libs/labelbox/tests/conftest.py +++ b/libs/labelbox/tests/conftest.py @@ -53,29 +53,30 @@ @pytest.fixture(scope="session") def rand_gen(): - def gen(field_type): if field_type is str: - return "".join(ascii_letters[randint(0, - len(ascii_letters) - 1)] - for _ in range(16)) + return "".join( + ascii_letters[randint(0, len(ascii_letters) - 1)] + for _ in range(16) + ) if field_type is datetime: return datetime.now() - raise Exception("Can't random generate for field type '%r'" % - field_type) + raise Exception( + "Can't random generate for field type '%r'" % field_type + ) return gen class Environ(Enum): - LOCAL = 'local' - PROD = 'prod' - STAGING = 'staging' - CUSTOM = 'custom' - STAGING_EU = 'staging-eu' - EPHEMERAL = 'ephemeral' # Used for testing PRs with ephemeral environments + LOCAL = "local" + PROD = "prod" + STAGING = "staging" + CUSTOM = "custom" + STAGING_EU = "staging-eu" + EPHEMERAL = "ephemeral" # Used for testing PRs with ephemeral environments @pytest.fixture @@ -89,48 +90,50 @@ def external_id() -> str: def ephemeral_endpoint() -> str: - return os.getenv('LABELBOX_TEST_BASE_URL', EPHEMERAL_BASE_URL) + return os.getenv("LABELBOX_TEST_BASE_URL", EPHEMERAL_BASE_URL) def graphql_url(environ: str) -> str: if environ == Environ.LOCAL: - return 'http://localhost:3000/api/graphql' + return "http://localhost:3000/api/graphql" elif environ == Environ.PROD: - return 'https://api.labelbox.com/graphql' + return "https://api.labelbox.com/graphql" elif environ == Environ.STAGING: - return 'https://api.lb-stage.xyz/graphql' + return "https://api.lb-stage.xyz/graphql" elif environ == Environ.CUSTOM: graphql_api_endpoint = os.environ.get( - 'LABELBOX_TEST_GRAPHQL_API_ENDPOINT') + "LABELBOX_TEST_GRAPHQL_API_ENDPOINT" + ) if graphql_api_endpoint is None: raise Exception("Missing LABELBOX_TEST_GRAPHQL_API_ENDPOINT") return graphql_api_endpoint elif environ == Environ.EPHEMERAL: return f"{ephemeral_endpoint()}/graphql" - return 'http://host.docker.internal:8080/graphql' + return "http://host.docker.internal:8080/graphql" def rest_url(environ: str) -> str: if environ == Environ.LOCAL: - return 'http://localhost:3000/api/v1' + return "http://localhost:3000/api/v1" elif environ 
== Environ.PROD: - return 'https://api.labelbox.com/api/v1' + return "https://api.labelbox.com/api/v1" elif environ == Environ.STAGING: - return 'https://api.lb-stage.xyz/api/v1' + return "https://api.lb-stage.xyz/api/v1" elif environ == Environ.CUSTOM: - rest_api_endpoint = os.environ.get('LABELBOX_TEST_REST_API_ENDPOINT') + rest_api_endpoint = os.environ.get("LABELBOX_TEST_REST_API_ENDPOINT") if rest_api_endpoint is None: raise Exception("Missing LABELBOX_TEST_REST_API_ENDPOINT") return rest_api_endpoint elif environ == Environ.EPHEMERAL: return f"{ephemeral_endpoint()}/api/v1" - return 'http://host.docker.internal:8080/api/v1' + return "http://host.docker.internal:8080/api/v1" def testing_api_key(environ: Environ) -> str: keys = [ f"LABELBOX_TEST_API_KEY_{environ.value.upper()}", - "LABELBOX_TEST_API_KEY", "LABELBOX_API_KEY" + "LABELBOX_TEST_API_KEY", + "LABELBOX_API_KEY", ] for key in keys: value = os.environ.get(key) @@ -143,47 +146,51 @@ def service_api_key() -> str: service_api_key = os.environ["SERVICE_API_KEY"] if service_api_key is None: raise Exception( - "SERVICE_API_KEY is missing and needed for admin client") + "SERVICE_API_KEY is missing and needed for admin client" + ) return service_api_key class IntegrationClient(Client): - def __init__(self, environ: str) -> None: api_url = graphql_url(environ) api_key = testing_api_key(environ) rest_endpoint = rest_url(environ) - super().__init__(api_key, - api_url, - enable_experimental=True, - rest_endpoint=rest_endpoint) + super().__init__( + api_key, + api_url, + enable_experimental=True, + rest_endpoint=rest_endpoint, + ) self.queries = [] def execute(self, query=None, params=None, check_naming=True, **kwargs): if check_naming and query is not None: - assert re.match(r"\s*(?:query|mutation) \w+PyApi", - query) is not None + assert ( + re.match(r"\s*(?:query|mutation) \w+PyApi", query) is not None + ) self.queries.append((query, params)) - if not kwargs.get('timeout'): - kwargs['timeout'] = 30.0 + if not kwargs.get("timeout"): + kwargs["timeout"] = 30.0 return super().execute(query, params, **kwargs) class AdminClient(Client): - def __init__(self, env): """ - The admin client creates organizations and users using admin api described here https://labelbox.atlassian.net/wiki/spaces/AP/pages/2206564433/Internal+Admin+APIs. + The admin client creates organizations and users using admin api described here https://labelbox.atlassian.net/wiki/spaces/AP/pages/2206564433/Internal+Admin+APIs. 
""" self._api_key = service_api_key() self._admin_endpoint = f"{ephemeral_endpoint()}/admin/v1" self._api_url = graphql_url(env) self._rest_endpoint = rest_url(env) - super().__init__(self._api_key, - self._api_url, - enable_experimental=True, - rest_endpoint=self._rest_endpoint) + super().__init__( + self._api_key, + self._api_url, + enable_experimental=True, + rest_endpoint=self._rest_endpoint, + ) def _create_organization(self) -> str: endpoint = f"{self._admin_endpoint}/organizations/" @@ -195,12 +202,14 @@ def _create_organization(self) -> str: data = response.json() if response.status_code not in [ - requests.codes.created, requests.codes.ok + requests.codes.created, + requests.codes.ok, ]: - raise Exception("Failed to create org, message: " + - str(data['message'])) + raise Exception( + "Failed to create org, message: " + str(data["message"]) + ) - return data['id'] + return data["id"] def _create_user(self, organization_id=None) -> Tuple[str, str]: if organization_id is None: @@ -221,31 +230,35 @@ def _create_user(self, organization_id=None) -> Tuple[str, str]: ) data = response.json() if response.status_code not in [ - requests.codes.created, requests.codes.ok + requests.codes.created, + requests.codes.ok, ]: - raise Exception("Failed to create user, message: " + - str(data['message'])) + raise Exception( + "Failed to create user, message: " + str(data["message"]) + ) - user_identity_id = data['identityId'] + user_identity_id = data["identityId"] - endpoint = f"{self._admin_endpoint}/organizations/{organization_id}/users/" + endpoint = ( + f"{self._admin_endpoint}/organizations/{organization_id}/users/" + ) response = requests.post( endpoint, headers=self.headers, - json={ - "identityId": user_identity_id, - "organizationRole": "Admin" - }, + json={"identityId": user_identity_id, "organizationRole": "Admin"}, ) data = response.json() if response.status_code not in [ - requests.codes.created, requests.codes.ok + requests.codes.created, + requests.codes.ok, ]: - raise Exception("Failed to create link user to org, message: " + - str(data['message'])) + raise Exception( + "Failed to create link user to org, message: " + + str(data["message"]) + ) - user_id = data['id'] + user_id = data["id"] endpoint = f"{self._admin_endpoint}/users/{user_id}/token" response = requests.get( @@ -254,10 +267,13 @@ def _create_user(self, organization_id=None) -> Tuple[str, str]: ) data = response.json() if response.status_code not in [ - requests.codes.created, requests.codes.ok + requests.codes.created, + requests.codes.ok, ]: - raise Exception("Failed to create ephemeral user, message: " + - str(data['message'])) + raise Exception( + "Failed to create ephemeral user, message: " + + str(data["message"]) + ) token = data["token"] @@ -282,17 +298,18 @@ def create_api_key_for_user(self) -> str: class EphemeralClient(Client): - def __init__(self, environ=Environ.EPHEMERAL): self.admin_client = AdminClient(environ) self.api_key = self.admin_client.create_api_key_for_user() api_url = graphql_url(environ) rest_endpoint = rest_url(environ) - super().__init__(self.api_key, - api_url, - enable_experimental=True, - rest_endpoint=rest_endpoint) + super().__init__( + self.api_key, + api_url, + enable_experimental=True, + rest_endpoint=rest_endpoint, + ) @pytest.fixture @@ -322,7 +339,7 @@ def environ() -> Environ: value = os.environ.get(key) if value is not None: return Environ(value) - raise Exception(f'Missing env key in: {os.environ}') + raise Exception(f"Missing env key in: {os.environ}") def 
cancel_invite(client, invite_id): @@ -331,7 +348,7 @@ def cancel_invite(client, invite_id): """ query_str = """mutation CancelInvitePyApi($where: WhereUniqueIdInput!) { cancelInvite(where: $where) {id}}""" - client.execute(query_str, {'where': {'id': invite_id}}, experimental=True) + client.execute(query_str, {"where": {"id": invite_id}}, experimental=True) def get_project_invites(client, project_id): @@ -344,11 +361,14 @@ def get_project_invites(client, project_id): invites(from: $from, first: $first) { nodes { %s projectInvites { projectId projectRoleName } } nextCursor}}} """ % (id_param, id_param, query.results_query_part(Invite)) - return PaginatedCollection(client, - query_str, {id_param: project_id}, - ['project', 'invites', 'nodes'], - Invite, - cursor_path=['project', 'invites', 'nextCursor']) + return PaginatedCollection( + client, + query_str, + {id_param: project_id}, + ["project", "invites", "nodes"], + Invite, + cursor_path=["project", "invites", "nextCursor"], + ) def get_invites(client): @@ -360,18 +380,23 @@ def get_invites(client): nodes { id createdAt organizationRoleName inviteeEmail } nextCursor }}}""" invites = PaginatedCollection( client, - query_str, {}, ['organization', 'invites', 'nodes'], + query_str, + {}, + ["organization", "invites", "nodes"], Invite, - cursor_path=['organization', 'invites', 'nextCursor'], - experimental=True) + cursor_path=["organization", "invites", "nextCursor"], + experimental=True, + ) return invites @pytest.fixture def queries(): - return SimpleNamespace(cancel_invite=cancel_invite, - get_project_invites=get_project_invites, - get_invites=get_invites) + return SimpleNamespace( + cancel_invite=cancel_invite, + get_project_invites=get_project_invites, + get_invites=get_invites, + ) @pytest.fixture(scope="session") @@ -388,52 +413,57 @@ def client(environ: str): @pytest.fixture(scope="session") def pdf_url(client): - pdf_url = client.upload_file('tests/assets/loremipsum.pdf') - return {"row_data": {"pdf_url": pdf_url,}, "global_key": str(uuid.uuid4())} + pdf_url = client.upload_file("tests/assets/loremipsum.pdf") + return { + "row_data": { + "pdf_url": pdf_url, + }, + "global_key": str(uuid.uuid4()), + } @pytest.fixture(scope="session") def pdf_entity_data_row(client): pdf_url = client.upload_file( - 'tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483.pdf') + "tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483.pdf" + ) text_layer_url = client.upload_file( - 'tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483-lb-textlayer.json' + "tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483-lb-textlayer.json" ) return { - "row_data": { - "pdf_url": pdf_url, - "text_layer_url": text_layer_url - }, - "global_key": str(uuid.uuid4()) + "row_data": {"pdf_url": pdf_url, "text_layer_url": text_layer_url}, + "global_key": str(uuid.uuid4()), } @pytest.fixture() def conversation_entity_data_row(client, rand_gen): return { - "row_data": - "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", - "global_key": - f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{rand_gen(str)}", + "row_data": "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", + "global_key": f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{rand_gen(str)}", } @pytest.fixture 
def project(client, rand_gen): - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) + project = client.create_project( + name=rand_gen(str), + queue_mode=QueueMode.Batch, + media_type=MediaType.Image, + ) yield project project.delete() @pytest.fixture def consensus_project(client, rand_gen): - project = client.create_project(name=rand_gen(str), - quality_mode=QualityMode.Consensus, - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) + project = client.create_project( + name=rand_gen(str), + quality_mode=QualityMode.Consensus, + queue_mode=QueueMode.Batch, + media_type=MediaType.Image, + ) yield project project.delete() @@ -443,23 +473,24 @@ def model_config(client, rand_gen, valid_model_id): model_config = client.create_model_config( name=rand_gen(str), model_id=valid_model_id, - inference_params={"param": "value"}) + inference_params={"param": "value"}, + ) yield model_config client.delete_model_config(model_config.uid) @pytest.fixture -def consensus_project_with_batch(consensus_project, initial_dataset, rand_gen, - image_url): +def consensus_project_with_batch( + consensus_project, initial_dataset, rand_gen, image_url +): project = consensus_project dataset = initial_dataset data_rows = [] for _ in range(3): - data_rows.append({ - DataRow.row_data: image_url, - DataRow.global_key: str(uuid.uuid4()) - }) + data_rows.append( + {DataRow.row_data: image_url, DataRow.global_key: str(uuid.uuid4())} + ) task = dataset.create_data_rows(data_rows) task.wait_till_done() assert task.status == "COMPLETE" @@ -469,7 +500,7 @@ def consensus_project_with_batch(consensus_project, initial_dataset, rand_gen, batch = project.create_batch( rand_gen(str), data_rows, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) yield [project, batch, data_rows] @@ -483,7 +514,7 @@ def dataset(client, rand_gen): dataset.delete() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def unique_dataset(client, rand_gen): dataset = client.create_dataset(name=rand_gen(str)) yield dataset @@ -492,12 +523,12 @@ def unique_dataset(client, rand_gen): @pytest.fixture def small_dataset(dataset: Dataset): - task = dataset.create_data_rows([ - { - "row_data": SMALL_DATASET_URL, - "external_id": "my-image" - }, - ] * 2) + task = dataset.create_data_rows( + [ + {"row_data": SMALL_DATASET_URL, "external_id": "my-image"}, + ] + * 2 + ) task.wait_till_done() yield dataset @@ -506,13 +537,15 @@ def small_dataset(dataset: Dataset): @pytest.fixture def data_row(dataset, image_url, rand_gen): global_key = f"global-key-{rand_gen(str)}" - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image", - "global_key": global_key - }, - ]) + task = dataset.create_data_rows( + [ + { + "row_data": image_url, + "external_id": "my-image", + "global_key": global_key, + }, + ] + ) task.wait_till_done() dr = dataset.data_rows().get_one() yield dr @@ -522,13 +555,15 @@ def data_row(dataset, image_url, rand_gen): @pytest.fixture def data_row_and_global_key(dataset, image_url, rand_gen): global_key = f"global-key-{rand_gen(str)}" - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image", - "global_key": global_key - }, - ]) + task = dataset.create_data_rows( + [ + { + "row_data": image_url, + "external_id": "my-image", + "global_key": global_key, + }, + ] + ) task.wait_till_done() dr = dataset.data_rows().get_one() yield dr, global_key @@ 
-539,10 +574,11 @@ def data_row_and_global_key(dataset, image_url, rand_gen): # @pytest.mark.parametrize('data_rows', [], indirect=True) # if omitted, count defaults to 1 @pytest.fixture -def data_rows(dataset, image_url, request, wait_for_data_row_processing, - client): +def data_rows( + dataset, image_url, request, wait_for_data_row_processing, client +): count = 1 - if hasattr(request, 'param'): + if hasattr(request, "param"): count = request.param datarows = [ @@ -565,26 +601,26 @@ def data_rows(dataset, image_url, request, wait_for_data_row_processing, @pytest.fixture def iframe_url(environ) -> str: if environ in [Environ.PROD, Environ.LOCAL]: - return 'https://editor.labelbox.com' + return "https://editor.labelbox.com" elif environ == Environ.STAGING: - return 'https://editor.lb-stage.xyz' + return "https://editor.lb-stage.xyz" @pytest.fixture def sample_image() -> str: - path_to_video = 'tests/integration/media/sample_image.jpg' + path_to_video = "tests/integration/media/sample_image.jpg" return path_to_video @pytest.fixture def sample_video() -> str: - path_to_video = 'tests/integration/media/cat.mp4' + path_to_video = "tests/integration/media/cat.mp4" return path_to_video @pytest.fixture def sample_bulk_conversation() -> list: - path_to_conversation = 'tests/integration/media/bulk_conversation.json' + path_to_conversation = "tests/integration/media/bulk_conversation.json" with open(path_to_conversation) as json_file: conversations = json.load(json_file) return conversations @@ -599,8 +635,15 @@ def organization(client): @pytest.fixture -def configured_project_with_label(client, rand_gen, image_url, project, dataset, - data_row, wait_for_label_processing): +def configured_project_with_label( + client, + rand_gen, + image_url, + project, + dataset, + data_row, + wait_for_label_processing, +): """Project with a connected dataset, having one datarow Project contains an ontology with 1 bbox tool Additionally includes a create_label method for any needed extra labels @@ -609,16 +652,18 @@ def configured_project_with_label(client, rand_gen, image_url, project, dataset, project._wait_until_data_rows_are_processed( data_row_ids=[data_row.uid], wait_processing_max_seconds=DATA_ROW_PROCESSING_WAIT_TIMEOUT_SECONDS, - sleep_interval=DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS) + sleep_interval=DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS, + ) project.create_batch( rand_gen(str), [data_row.uid], # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) ontology = _setup_ontology(project) - label = _create_label(project, data_row, ontology, - wait_for_label_processing) + label = _create_label( + project, data_row, ontology, wait_for_label_processing + ) yield [project, dataset, data_row, label] for label in project.labels(): @@ -626,32 +671,32 @@ def configured_project_with_label(client, rand_gen, image_url, project, dataset, def _create_label(project, data_row, ontology, wait_for_label_processing): - predictions = [{ - "uuid": str(uuid.uuid4()), - "schemaId": ontology.tools[0].feature_schema_id, - "dataRow": { - "id": data_row.uid - }, - "bbox": { - "top": 20, - "left": 20, - "height": 50, - "width": 50 + predictions = [ + { + "uuid": str(uuid.uuid4()), + "schemaId": ontology.tools[0].feature_schema_id, + "dataRow": {"id": data_row.uid}, + "bbox": {"top": 20, "left": 20, "height": 50, "width": 50}, } - }] + ] def create_label(): - """ Ad-hoc function to create a LabelImport + """Ad-hoc function to create a LabelImport 
Creates a LabelImport task which will create a label """ upload_task = LabelImport.create_from_objects( - project.client, project.uid, f'label-import-{uuid.uuid4()}', - predictions) + project.client, + project.uid, + f"label-import-{uuid.uuid4()}", + predictions, + ) upload_task.wait_until_done(sleep_time_seconds=5) - assert upload_task.state == AnnotationImportState.FINISHED, "Label Import did not finish" - assert len( - upload_task.errors - ) == 0, f"Label Import {upload_task.name} failed with errors {upload_task.errors}" + assert ( + upload_task.state == AnnotationImportState.FINISHED + ), "Label Import did not finish" + assert ( + len(upload_task.errors) == 0 + ), f"Label Import {upload_task.name} failed with errors {upload_task.errors}" project.create_label = create_label project.create_label() @@ -662,10 +707,14 @@ def create_label(): def _setup_ontology(project): editor = list( project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - ontology_builder = OntologyBuilder(tools=[ - Tool(tool=Tool.Type.BBOX, name="test-bbox-class"), - ]) + where=LabelingFrontend.name == "editor" + ) + )[0] + ontology_builder = OntologyBuilder( + tools=[ + Tool(tool=Tool.Type.BBOX, name="test-bbox-class"), + ] + ) project.setup(editor, ontology_builder.asdict()) # TODO: ontology may not be synchronous after setup. remove sleep when api is more consistent time.sleep(2) @@ -674,34 +723,37 @@ def _setup_ontology(project): @pytest.fixture def big_dataset(dataset: Dataset): - task = dataset.create_data_rows([ - { - "row_data": IMAGE_URL, - "external_id": EXTERNAL_ID - }, - ] * 3) + task = dataset.create_data_rows( + [ + {"row_data": IMAGE_URL, "external_id": EXTERNAL_ID}, + ] + * 3 + ) task.wait_till_done() yield dataset @pytest.fixture -def configured_batch_project_with_label(project, dataset, data_row, - wait_for_label_processing): +def configured_batch_project_with_label( + project, dataset, data_row, wait_for_label_processing +): """Project with a batch having one datarow Project contains an ontology with 1 bbox tool Additionally includes a create_label method for any needed extra labels One label is already created and yielded when using fixture """ data_rows = [dr.uid for dr in list(dataset.data_rows())] - project._wait_until_data_rows_are_processed(data_row_ids=data_rows, - sleep_interval=3) + project._wait_until_data_rows_are_processed( + data_row_ids=data_rows, sleep_interval=3 + ) project.create_batch("test-batch", data_rows) project.data_row_ids = data_rows ontology = _setup_ontology(project) - label = _create_label(project, data_row, ontology, - wait_for_label_processing) + label = _create_label( + project, data_row, ontology, wait_for_label_processing + ) yield [project, dataset, data_row, label] @@ -710,15 +762,16 @@ def configured_batch_project_with_label(project, dataset, data_row, @pytest.fixture -def configured_batch_project_with_multiple_datarows(project, dataset, data_rows, - wait_for_label_processing): +def configured_batch_project_with_multiple_datarows( + project, dataset, data_rows, wait_for_label_processing +): """Project with a batch having multiple datarows Project contains an ontology with 1 bbox tool Additionally includes a create_label method for any needed extra labels """ global_keys = [dr.global_key for dr in data_rows] - batch_name = f'batch {uuid.uuid4()}' + batch_name = f"batch {uuid.uuid4()}" project.create_batch(batch_name, global_keys=global_keys) ontology = _setup_ontology(project) @@ -732,15 +785,16 @@ def 
configured_batch_project_with_multiple_datarows(project, dataset, data_rows, @pytest.fixture -def configured_batch_project_for_labeling_service(project, - data_row_and_global_key): +def configured_batch_project_for_labeling_service( + project, data_row_and_global_key +): """Project with a batch having multiple datarows Project contains an ontology with 1 bbox tool Additionally includes a create_label method for any needed extra labels """ global_keys = [data_row_and_global_key[1]] - batch_name = f'batch {uuid.uuid4()}' + batch_name = f"batch {uuid.uuid4()}" project.create_batch(batch_name, global_keys=global_keys) _setup_ontology(project) @@ -830,12 +884,9 @@ def video_data(client, rand_gen, video_data_row, wait_for_data_row_processing): def create_video_data_row(rand_gen): return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{rand_gen(str)}", - "media_type": - "VIDEO", + "row_data": "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", + "global_key": f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{rand_gen(str)}", + "media_type": "VIDEO", } @@ -857,25 +908,25 @@ def video_data_row(rand_gen): class ExportV2Helpers: - @classmethod - def run_project_export_v2_task(cls, - project, - num_retries=5, - task_name=None, - filters={}, - params={}): + def run_project_export_v2_task( + cls, project, num_retries=5, task_name=None, filters={}, params={} + ): task = None - params = params if params else { - "project_details": True, - "performance_details": False, - "data_row_details": True, - "label_details": True - } - while (num_retries > 0): - task = project.export_v2(task_name=task_name, - filters=filters, - params=params) + params = ( + params + if params + else { + "project_details": True, + "performance_details": False, + "data_row_details": True, + "label_details": True, + } + ) + while num_retries > 0: + task = project.export_v2( + task_name=task_name, filters=filters, params=params + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None @@ -887,21 +938,19 @@ def run_project_export_v2_task(cls, return task.result @classmethod - def run_dataset_export_v2_task(cls, - dataset, - num_retries=5, - task_name=None, - filters={}, - params={}): + def run_dataset_export_v2_task( + cls, dataset, num_retries=5, task_name=None, filters={}, params={} + ): task = None - params = params if params else { - "performance_details": False, - "label_details": True - } - while (num_retries > 0): - task = dataset.export_v2(task_name=task_name, - filters=filters, - params=params) + params = ( + params + if params + else {"performance_details": False, "label_details": True} + ) + while num_retries > 0: + task = dataset.export_v2( + task_name=task_name, filters=filters, params=params + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None @@ -914,23 +963,20 @@ def run_dataset_export_v2_task(cls, return task.result @classmethod - def run_catalog_export_v2_task(cls, - client, - num_retries=5, - task_name=None, - filters={}, - params={}): + def run_catalog_export_v2_task( + cls, client, num_retries=5, task_name=None, filters={}, params={} + ): task = None - params = params if params else { - "performance_details": False, - "label_details": True - } + params = ( + params + if params + else {"performance_details": False, 
"label_details": True} + ) catalog = client.get_catalog() - while (num_retries > 0): - - task = catalog.export_v2(task_name=task_name, - filters=filters, - params=params) + while num_retries > 0: + task = catalog.export_v2( + task_name=task_name, filters=filters, params=params + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None @@ -956,9 +1002,10 @@ def big_dataset_data_row_ids(big_dataset: Dataset): yield [dr.json["data_row"]["id"] for dr in stream] -@pytest.fixture(scope='function') -def dataset_with_invalid_data_rows(unique_dataset: Dataset, - upload_invalid_data_rows_for_dataset): +@pytest.fixture(scope="function") +def dataset_with_invalid_data_rows( + unique_dataset: Dataset, upload_invalid_data_rows_for_dataset +): upload_invalid_data_rows_for_dataset(unique_dataset) yield unique_dataset @@ -966,22 +1013,25 @@ def dataset_with_invalid_data_rows(unique_dataset: Dataset, @pytest.fixture def upload_invalid_data_rows_for_dataset(): - def _upload_invalid_data_rows_for_dataset(dataset: Dataset): - task = dataset.create_data_rows([ - { - "row_data": 'gs://invalid-bucket/example.png', # forbidden - "external_id": "image-without-access.jpg" - }, - ] * 2) + task = dataset.create_data_rows( + [ + { + "row_data": "gs://invalid-bucket/example.png", # forbidden + "external_id": "image-without-access.jpg", + }, + ] + * 2 + ) task.wait_till_done() return _upload_invalid_data_rows_for_dataset @pytest.fixture -def configured_project(project_with_empty_ontology, initial_dataset, rand_gen, - image_url): +def configured_project( + project_with_empty_ontology, initial_dataset, rand_gen, image_url +): dataset = initial_dataset data_row_id = dataset.create_data_row(row_data=image_url).uid project = project_with_empty_ontology @@ -989,7 +1039,7 @@ def configured_project(project_with_empty_ontology, initial_dataset, rand_gen, batch = project.create_batch( rand_gen(str), [data_row_id], # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) project.data_row_ids = [data_row_id] @@ -1002,18 +1052,23 @@ def configured_project(project_with_empty_ontology, initial_dataset, rand_gen, def project_with_empty_ontology(project): editor = list( project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] + where=LabelingFrontend.name == "editor" + ) + )[0] empty_ontology = {"tools": [], "classifications": []} project.setup(editor, empty_ontology) yield project @pytest.fixture -def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, - image_url): - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) +def configured_project_with_complex_ontology( + client, initial_dataset, rand_gen, image_url +): + project = client.create_project( + name=rand_gen(str), + queue_mode=QueueMode.Batch, + media_type=MediaType.Image, + ) dataset = initial_dataset data_row = dataset.create_data_row(row_data=image_url) data_row_ids = [data_row.uid] @@ -1021,13 +1076,15 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, project.create_batch( rand_gen(str), data_row_ids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) project.data_row_ids = data_row_ids editor = list( project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] + where=LabelingFrontend.name == "editor" + ) + )[0] ontology = OntologyBuilder() 
tools = [ @@ -1035,24 +1092,29 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, Tool(tool=Tool.Type.LINE, name="test-line-class"), Tool(tool=Tool.Type.POINT, name="test-point-class"), Tool(tool=Tool.Type.POLYGON, name="test-polygon-class"), - Tool(tool=Tool.Type.NER, name="test-ner-class") + Tool(tool=Tool.Type.NER, name="test-ner-class"), ] options = [ Option(value="first option answer"), Option(value="second option answer"), - Option(value="third option answer") + Option(value="third option answer"), ] classifications = [ - Classification(class_type=Classification.Type.TEXT, - name="test-text-class"), - Classification(class_type=Classification.Type.RADIO, - name="test-radio-class", - options=options), - Classification(class_type=Classification.Type.CHECKLIST, - name="test-checklist-class", - options=options) + Classification( + class_type=Classification.Type.TEXT, name="test-text-class" + ), + Classification( + class_type=Classification.Type.RADIO, + name="test-radio-class", + options=options, + ), + Classification( + class_type=Classification.Type.CHECKLIST, + name="test-checklist-class", + options=options, + ), ] for t in tools: @@ -1070,7 +1132,6 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, @pytest.fixture def embedding(client: Client, environ): - uuid_str = uuid.uuid4().hex time.sleep(randint(1, 5)) embedding = client.create_embedding(f"sdk-int-{uuid_str}", 8) @@ -1085,13 +1146,16 @@ def valid_model_id(): @pytest.fixture -def requested_labeling_service(rand_gen, - live_chat_evaluation_project_with_new_dataset, - chat_evaluation_ontology, model_config): +def requested_labeling_service( + rand_gen, + live_chat_evaluation_project_with_new_dataset, + chat_evaluation_ontology, + model_config, +): project = live_chat_evaluation_project_with_new_dataset project.connect_ontology(chat_evaluation_ontology) - project.upsert_instructions('tests/integration/media/sample_pdf.pdf') + project.upsert_instructions("tests/integration/media/sample_pdf.pdf") labeling_service = project.get_labeling_service() project.add_model_config(model_config.uid) diff --git a/libs/labelbox/tests/data/annotation_import/conftest.py b/libs/labelbox/tests/data/annotation_import/conftest.py index 370af0517..39cede0bb 100644 --- a/libs/labelbox/tests/data/annotation_import/conftest.py +++ b/libs/labelbox/tests/data/annotation_import/conftest.py @@ -15,6 +15,7 @@ from labelbox.schema.annotation_import import LabelImport, AnnotationImportState from pytest import FixtureRequest from contextlib import suppress + """ The main fixtures of this library are configured_project and configured_project_by_global_key. Both fixtures generate data rows with a parametrize media type. They create the amount of data rows equal to the DATA_ROW_COUNT variable below. The data rows are generated with a factory fixture that returns a function that allows you to pass a global key. The ontologies are generated normalized and based on the MediaType given (i.e. only features supported by MediaType are created). This ontology is later used to obtain the correct annotations with the prediction_id_mapping and corresponding inferences. Each data row will have all possible annotations attached supported for the MediaType. 
""" @@ -26,15 +27,11 @@ @pytest.fixture(scope="module", autouse=True) def video_data_row_factory(): - def video_data_row(global_key): return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{global_key}", - "media_type": - "VIDEO", + "row_data": "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", + "global_key": f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{global_key}", + "media_type": "VIDEO", } return video_data_row @@ -42,15 +39,11 @@ def video_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def audio_data_row_factory(): - def audio_data_row(global_key): return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/audio-sample-data/sample-audio-1.mp3", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/audio-sample-data/sample-audio-1.mp3-{global_key}", - "media_type": - "AUDIO", + "row_data": "https://storage.googleapis.com/labelbox-datasets/audio-sample-data/sample-audio-1.mp3", + "global_key": f"https://storage.googleapis.com/labelbox-datasets/audio-sample-data/sample-audio-1.mp3-{global_key}", + "media_type": "AUDIO", } return audio_data_row @@ -58,13 +51,10 @@ def audio_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def conversational_data_row_factory(): - def conversational_data_row(global_key): return { - "row_data": - "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", - "global_key": - f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{global_key}", + "row_data": "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", + "global_key": f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{global_key}", } return conversational_data_row @@ -72,15 +62,11 @@ def conversational_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def dicom_data_row_factory(): - def dicom_data_row(global_key): return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/dicom-sample-data/sample-dicom-1.dcm", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/dicom-sample-data/sample-dicom-1.dcm-{global_key}", - "media_type": - "DICOM", + "row_data": "https://storage.googleapis.com/labelbox-datasets/dicom-sample-data/sample-dicom-1.dcm", + "global_key": f"https://storage.googleapis.com/labelbox-datasets/dicom-sample-data/sample-dicom-1.dcm-{global_key}", + "media_type": "DICOM", } return dicom_data_row @@ -88,27 +74,20 @@ def dicom_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def geospatial_data_row_factory(): - def geospatial_data_row(global_key): return { "row_data": { - "tile_layer_url": - "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png", + "tile_layer_url": "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png", "bounds": [ [19.405662413477728, -99.21052827588443], [19.400498983095076, -99.20534818927473], ], - "min_zoom": - 12, - "max_zoom": - 20, - "epsg": - "EPSG4326", + "min_zoom": 12, + "max_zoom": 20, + "epsg": "EPSG4326", }, - "global_key": - 
f"https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/z/x/y.png-{global_key}", - "media_type": - "TMS_GEO", + "global_key": f"https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/z/x/y.png-{global_key}", + "media_type": "TMS_GEO", } return geospatial_data_row @@ -116,13 +95,10 @@ def geospatial_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def html_data_row_factory(): - def html_data_row(global_key): return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html-{global_key}", + "row_data": "https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html", + "global_key": f"https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html-{global_key}", } return html_data_row @@ -130,15 +106,11 @@ def html_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def image_data_row_factory(): - def image_data_row(global_key): return { - "row_data": - "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg", - "global_key": - f"https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-{global_key}", - "media_type": - "IMAGE", + "row_data": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg", + "global_key": f"https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-{global_key}", + "media_type": "IMAGE", } return image_data_row @@ -146,19 +118,14 @@ def image_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def document_data_row_factory(): - def document_data_row(global_key): return { "row_data": { - "pdf_url": - "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf", - "text_layer_url": - "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483-lb-textlayer.json", + "pdf_url": "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf", + "text_layer_url": "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483-lb-textlayer.json", }, - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf-{global_key}", - "media_type": - "PDF", + "global_key": f"https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf-{global_key}", + "media_type": "PDF", } return document_data_row @@ -166,15 +133,11 @@ def document_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def text_data_row_factory(): - def text_data_row(global_key): return { - "row_data": - "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-2.txt", - "global_key": - f"https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-2.txt-{global_key}", - "media_type": - "TEXT", + "row_data": "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-2.txt", + "global_key": f"https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-2.txt-{global_key}", + "media_type": "TEXT", } return text_data_row @@ -182,13 +145,10 @@ def text_data_row(global_key): @pytest.fixture(scope="module", autouse=True) def llm_human_preference_data_row_factory(): - def 
llm_human_preference_data_row(global_key): return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/sdk_test/llm_prompt_response_conv.json", - "global_key": - global_key, + "row_data": "https://storage.googleapis.com/labelbox-datasets/sdk_test/llm_prompt_response_conv.json", + "global_key": global_key, } return llm_human_preference_data_row @@ -224,60 +184,50 @@ def normalized_ontology_by_media_type(): """Returns NDJSON of ontology based on media type""" bbox_tool_with_nested_text = { - "required": - False, - "name": - "bbox_tool_with_nested_text", - "tool": - "rectangle", - "color": - "#a23030", - "classifications": [{ - "required": - False, - "instructions": - "nested", - "name": - "nested", - "type": - "radio", - "options": [{ - "label": - "radio_value_1", - "value": - "radio_value_1", + "required": False, + "name": "bbox_tool_with_nested_text", + "tool": "rectangle", + "color": "#a23030", + "classifications": [ + { + "required": False, + "instructions": "nested", + "name": "nested", + "type": "radio", "options": [ { - "required": - False, - "instructions": - "nested_checkbox", - "name": - "nested_checkbox", - "type": - "checklist", + "label": "radio_value_1", + "value": "radio_value_1", "options": [ { - "label": "nested_checkbox_option_1", - "value": "nested_checkbox_option_1", - "options": [], + "required": False, + "instructions": "nested_checkbox", + "name": "nested_checkbox", + "type": "checklist", + "options": [ + { + "label": "nested_checkbox_option_1", + "value": "nested_checkbox_option_1", + "options": [], + }, + { + "label": "nested_checkbox_option_2", + "value": "nested_checkbox_option_2", + }, + ], }, { - "label": "nested_checkbox_option_2", - "value": "nested_checkbox_option_2", + "required": False, + "instructions": "nested_text", + "name": "nested_text", + "type": "text", + "options": [], }, ], }, - { - "required": False, - "instructions": "nested_text", - "name": "nested_text", - "type": "text", - "options": [], - }, ], - },], - }], + } + ], } bbox_tool = { @@ -331,44 +281,35 @@ def normalized_ontology_by_media_type(): "classifications": [], } checklist = { - "required": - False, - "instructions": - "checklist", - "name": - "checklist", - "type": - "checklist", + "required": False, + "instructions": "checklist", + "name": "checklist", + "type": "checklist", "options": [ { "label": "first_checklist_answer", - "value": "first_checklist_answer" + "value": "first_checklist_answer", }, { "label": "second_checklist_answer", - "value": "second_checklist_answer" + "value": "second_checklist_answer", }, ], } checklist_index = { - "required": - False, - "instructions": - "checklist_index", - "name": - "checklist_index", - "type": - "checklist", - "scope": - "index", + "required": False, + "instructions": "checklist_index", + "name": "checklist_index", + "type": "checklist", + "scope": "index", "options": [ { "label": "first_checklist_answer", - "value": "first_checklist_answer" + "value": "first_checklist_answer", }, { "label": "second_checklist_answer", - "value": "second_checklist_answer" + "value": "second_checklist_answer", }, ], } @@ -388,14 +329,10 @@ def normalized_ontology_by_media_type(): "options": [], } radio = { - "required": - False, - "instructions": - "radio", - "name": - "radio", - "type": - "radio", + "required": False, + "instructions": "radio", + "name": "radio", + "type": "radio", "options": [ { "label": "first_radio_answer", @@ -418,39 +355,45 @@ def normalized_ontology_by_media_type(): "maxCharacters": 50, "minCharacters": 1, "schemaNodeId": 
None, - "type": "prompt" + "type": "prompt", } response_radio = { "instructions": "radio-response", "name": "radio-response", - "options": [{ - "label": "first_radio_answer", - "value": "first_radio_answer", - "options": [] - }, { - "label": "second_radio_answer", - "value": "second_radio_answer", - "options": [] - }], + "options": [ + { + "label": "first_radio_answer", + "value": "first_radio_answer", + "options": [], + }, + { + "label": "second_radio_answer", + "value": "second_radio_answer", + "options": [], + }, + ], "required": True, - "type": "response-radio" + "type": "response-radio", } response_checklist = { "instructions": "checklist-response", "name": "checklist-response", - "options": [{ - "label": "first_checklist_answer", - "value": "first_checklist_answer", - "options": [] - }, { - "label": "second_checklist_answer", - "value": "second_checklist_answer", - "options": [] - }], + "options": [ + { + "label": "first_checklist_answer", + "value": "first_checklist_answer", + "options": [], + }, + { + "label": "second_checklist_answer", + "value": "second_checklist_answer", + "options": [], + }, + ], "required": True, - "type": "response-checklist" + "type": "response-checklist", } response_text = { @@ -459,7 +402,7 @@ def normalized_ontology_by_media_type(): "minCharacters": 1, "name": "response-text", "required": True, - "type": "response-text" + "type": "response-text", } return { @@ -476,7 +419,7 @@ def normalized_ontology_by_media_type(): checklist, free_form_text, radio, - ] + ], }, MediaType.Text: { "tools": [entity_tool], @@ -484,7 +427,7 @@ def normalized_ontology_by_media_type(): checklist, free_form_text, radio, - ] + ], }, MediaType.Video: { "tools": [ @@ -495,9 +438,12 @@ def normalized_ontology_by_media_type(): raster_segmentation_tool, ], "classifications": [ - checklist, free_form_text, radio, checklist_index, - free_form_text_index - ] + checklist, + free_form_text, + radio, + checklist_index, + free_form_text_index, + ], }, MediaType.Geospatial_Tile: { "tools": [ @@ -511,7 +457,7 @@ def normalized_ontology_by_media_type(): checklist, free_form_text, radio, - ] + ], }, MediaType.Document: { "tools": [entity_tool, bbox_tool, bbox_tool_with_nested_text], @@ -519,7 +465,7 @@ def normalized_ontology_by_media_type(): checklist, free_form_text, radio, - ] + ], }, MediaType.Audio: { "tools": [], @@ -527,7 +473,7 @@ def normalized_ontology_by_media_type(): checklist, free_form_text, radio, - ] + ], }, MediaType.Html: { "tools": [], @@ -535,34 +481,42 @@ def normalized_ontology_by_media_type(): checklist, free_form_text, radio, - ] + ], }, MediaType.Dicom: { "tools": [raster_segmentation_tool, polyline_tool], - "classifications": [] + "classifications": [], }, MediaType.Conversational: { "tools": [entity_tool], "classifications": [ - checklist, free_form_text, radio, checklist_index, - free_form_text_index - ] + checklist, + free_form_text, + radio, + checklist_index, + free_form_text_index, + ], }, MediaType.LLMPromptResponseCreation: { "tools": [], "classifications": [ - prompt_text, response_text, response_radio, response_checklist - ] + prompt_text, + response_text, + response_radio, + response_checklist, + ], }, MediaType.LLMPromptCreation: { "tools": [], - "classifications": [prompt_text] + "classifications": [prompt_text], }, OntologyKind.ResponseCreation: { "tools": [], "classifications": [ - response_text, response_radio, response_checklist - ] + response_text, + response_radio, + response_checklist, + ], }, "all": { "tools": [ @@ -581,8 +535,8 @@ def 
normalized_ontology_by_media_type(): free_form_text, free_form_text_index, radio, - ] - } + ], + }, } @@ -617,7 +571,7 @@ def func(project): @pytest.fixture def hardcoded_datarow_id(): - data_row_id = 'ck8q9q9qj00003g5z3q1q9q9q' + data_row_id = "ck8q9q9qj00003g5z3q1q9q9q" def get_data_row_id(): return data_row_id @@ -639,33 +593,40 @@ def get_global_key(): def _create_response_creation_project( - client: Client, rand_gen, data_row_json_by_media_type, ontology_kind, - normalized_ontology_by_media_type) -> Tuple[Project, Ontology, Dataset]: + client: Client, + rand_gen, + data_row_json_by_media_type, + ontology_kind, + normalized_ontology_by_media_type, +) -> Tuple[Project, Ontology, Dataset]: "For response creation projects" dataset = client.create_dataset(name=rand_gen(str)) project = client.create_response_creation_project( - name=f"{ontology_kind}-{rand_gen(str)}") + name=f"{ontology_kind}-{rand_gen(str)}" + ) ontology = client.create_ontology( name=f"{ontology_kind}-{rand_gen(str)}", normalized=normalized_ontology_by_media_type[ontology_kind], media_type=MediaType.Text, - ontology_kind=ontology_kind) + ontology_kind=ontology_kind, + ) project.connect_ontology(ontology) data_row_data = [] for _ in range(DATA_ROW_COUNT): - data_row_data.append(data_row_json_by_media_type[MediaType.Text]( - rand_gen(str))) + data_row_data.append( + data_row_json_by_media_type[MediaType.Text](rand_gen(str)) + ) task = dataset.create_data_rows(data_row_data) task.wait_till_done() - global_keys = [row['global_key'] for row in task.result] - data_row_ids = [row['id'] for row in task.result] + global_keys = [row["global_key"] for row in task.result] + data_row_ids = [row["id"] for row in task.result] project.create_batch( rand_gen(str), @@ -679,16 +640,15 @@ def _create_response_creation_project( @pytest.fixture -def llm_prompt_response_creation_dataset_with_data_row(client: Client, - rand_gen): +def llm_prompt_response_creation_dataset_with_data_row( + client: Client, rand_gen +): dataset = client.create_dataset(name=rand_gen(str)) global_key = str(uuid.uuid4()) convo_data = { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/conversational-sample-data/pairwise_shopping_2.json", - "global_key": - global_key + "row_data": "https://storage.googleapis.com/labelbox-datasets/conversational-sample-data/pairwise_shopping_2.json", + "global_key": global_key, } task = dataset.create_data_rows([convo_data]) @@ -700,26 +660,33 @@ def llm_prompt_response_creation_dataset_with_data_row(client: Client, def _create_prompt_response_project( - client: Client, rand_gen, media_type, normalized_ontology_by_media_type, - export_v2_test_helpers, llm_prompt_response_creation_dataset_with_data_row + client: Client, + rand_gen, + media_type, + normalized_ontology_by_media_type, + export_v2_test_helpers, + llm_prompt_response_creation_dataset_with_data_row, ) -> Tuple[Project, Ontology]: """For prompt response data row auto gen projects""" dataset = llm_prompt_response_creation_dataset_with_data_row prompt_response_project = client.create_prompt_response_generation_project( name=f"{media_type.value}-{rand_gen(str)}", dataset_id=dataset.uid, - media_type=media_type) + media_type=media_type, + ) ontology = client.create_ontology( name=f"{media_type}-{rand_gen(str)}", normalized=normalized_ontology_by_media_type[media_type], - media_type=media_type) + media_type=media_type, + ) prompt_response_project.connect_ontology(ontology) # We have to export to get data row ids result = 
export_v2_test_helpers.run_project_export_v2_task( - prompt_response_project) + prompt_response_project + ) data_row_ids = [dr["data_row"]["id"] for dr in result] global_keys = [dr["data_row"]["global_key"] for dr in result] @@ -731,32 +698,39 @@ def _create_prompt_response_project( def _create_project( - client: Client, rand_gen, data_row_json_by_media_type, media_type, - normalized_ontology_by_media_type) -> Tuple[Project, Ontology, Dataset]: - """ Shared function to configure project for integration tests """ + client: Client, + rand_gen, + data_row_json_by_media_type, + media_type, + normalized_ontology_by_media_type, +) -> Tuple[Project, Ontology, Dataset]: + """Shared function to configure project for integration tests""" dataset = client.create_dataset(name=rand_gen(str)) - project = client.create_project(name=f"{media_type}-{rand_gen(str)}", - media_type=media_type) + project = client.create_project( + name=f"{media_type}-{rand_gen(str)}", media_type=media_type + ) ontology = client.create_ontology( name=f"{media_type}-{rand_gen(str)}", normalized=normalized_ontology_by_media_type[media_type], - media_type=media_type) + media_type=media_type, + ) project.connect_ontology(ontology) data_row_data = [] for _ in range(DATA_ROW_COUNT): - data_row_data.append(data_row_json_by_media_type[media_type]( - rand_gen(str))) + data_row_data.append( + data_row_json_by_media_type[media_type](rand_gen(str)) + ) task = dataset.create_data_rows(data_row_data) task.wait_till_done() - global_keys = [row['global_key'] for row in task.result] - data_row_ids = [row['id'] for row in task.result] + global_keys = [row["global_key"] for row in task.result] + data_row_ids = [row["id"] for row in task.result] project.create_batch( rand_gen(str), @@ -770,29 +744,48 @@ def _create_project( @pytest.fixture -def configured_project(client: Client, rand_gen, data_row_json_by_media_type, - request: FixtureRequest, - normalized_ontology_by_media_type, - export_v2_test_helpers, - llm_prompt_response_creation_dataset_with_data_row): +def configured_project( + client: Client, + rand_gen, + data_row_json_by_media_type, + request: FixtureRequest, + normalized_ontology_by_media_type, + export_v2_test_helpers, + llm_prompt_response_creation_dataset_with_data_row, +): """Configure project for test. Request.param will contain the media type if not present will use Image MediaType. 
The project will have 10 data rows.""" media_type = getattr(request, "param", MediaType.Image) dataset = None - if media_type == MediaType.LLMPromptCreation or media_type == MediaType.LLMPromptResponseCreation: + if ( + media_type == MediaType.LLMPromptCreation + or media_type == MediaType.LLMPromptResponseCreation + ): project, ontology = _create_prompt_response_project( - client, rand_gen, media_type, normalized_ontology_by_media_type, + client, + rand_gen, + media_type, + normalized_ontology_by_media_type, export_v2_test_helpers, - llm_prompt_response_creation_dataset_with_data_row) + llm_prompt_response_creation_dataset_with_data_row, + ) elif media_type == OntologyKind.ResponseCreation: project, ontology, dataset = _create_response_creation_project( - client, rand_gen, data_row_json_by_media_type, media_type, - normalized_ontology_by_media_type) + client, + rand_gen, + data_row_json_by_media_type, + media_type, + normalized_ontology_by_media_type, + ) else: project, ontology, dataset = _create_project( - client, rand_gen, data_row_json_by_media_type, media_type, - normalized_ontology_by_media_type) + client, + rand_gen, + data_row_json_by_media_type, + media_type, + normalized_ontology_by_media_type, + ) yield project @@ -805,28 +798,46 @@ def configured_project(client: Client, rand_gen, data_row_json_by_media_type, @pytest.fixture() -def configured_project_by_global_key(client: Client, rand_gen, - data_row_json_by_media_type, - request: FixtureRequest, - normalized_ontology_by_media_type, - export_v2_test_helpers): +def configured_project_by_global_key( + client: Client, + rand_gen, + data_row_json_by_media_type, + request: FixtureRequest, + normalized_ontology_by_media_type, + export_v2_test_helpers, +): """Does the same thing as configured project but with global keys focus.""" media_type = getattr(request, "param", MediaType.Image) dataset = None - if media_type == MediaType.LLMPromptCreation or media_type == MediaType.LLMPromptResponseCreation: + if ( + media_type == MediaType.LLMPromptCreation + or media_type == MediaType.LLMPromptResponseCreation + ): project, ontology = _create_prompt_response_project( - client, rand_gen, media_type, normalized_ontology_by_media_type, - export_v2_test_helpers) + client, + rand_gen, + media_type, + normalized_ontology_by_media_type, + export_v2_test_helpers, + ) elif media_type == OntologyKind.ResponseCreation: project, ontology, dataset = _create_response_creation_project( - client, rand_gen, data_row_json_by_media_type, media_type, - normalized_ontology_by_media_type) + client, + rand_gen, + data_row_json_by_media_type, + media_type, + normalized_ontology_by_media_type, + ) else: project, ontology, dataset = _create_project( - client, rand_gen, data_row_json_by_media_type, media_type, - normalized_ontology_by_media_type) + client, + rand_gen, + data_row_json_by_media_type, + media_type, + normalized_ontology_by_media_type, + ) yield project @@ -839,25 +850,42 @@ def configured_project_by_global_key(client: Client, rand_gen, @pytest.fixture(scope="module") -def module_project(client: Client, rand_gen, data_row_json_by_media_type, - request: FixtureRequest, normalized_ontology_by_media_type): +def module_project( + client: Client, + rand_gen, + data_row_json_by_media_type, + request: FixtureRequest, + normalized_ontology_by_media_type, +): """Generates a image project that scopes to the test module(file). 
Used to reduce api calls.""" media_type = getattr(request, "param", MediaType.Image) media_type = getattr(request, "param", MediaType.Image) dataset = None - if media_type == MediaType.LLMPromptCreation or media_type == MediaType.LLMPromptResponseCreation: + if ( + media_type == MediaType.LLMPromptCreation + or media_type == MediaType.LLMPromptResponseCreation + ): project, ontology = _create_prompt_response_project( - client, rand_gen, media_type, normalized_ontology_by_media_type) + client, rand_gen, media_type, normalized_ontology_by_media_type + ) elif media_type == OntologyKind.ResponseCreation: project, ontology, dataset = _create_response_creation_project( - client, rand_gen, data_row_json_by_media_type, media_type, - normalized_ontology_by_media_type) + client, + rand_gen, + data_row_json_by_media_type, + media_type, + normalized_ontology_by_media_type, + ) else: project, ontology, dataset = _create_project( - client, rand_gen, data_row_json_by_media_type, media_type, - normalized_ontology_by_media_type) + client, + rand_gen, + data_row_json_by_media_type, + media_type, + normalized_ontology_by_media_type, + ) yield project @@ -872,17 +900,17 @@ def module_project(client: Client, rand_gen, data_row_json_by_media_type, @pytest.fixture def prediction_id_mapping(request, normalized_ontology_by_media_type): """Creates the base of annotation based on tools inside project ontology. We would want only annotations supported for the MediaType of the ontology and project. Annotations are generated for each data row created later be combined inside the test file. This serves as the base fixture for all the interference (annotations) fixture. This fixtures supports a few strategies: - + Integration test: configured_project: generates data rows with data row id focus. configured_project_by_global_key: generates data rows with global key focus. module_configured_project: configured project but scoped to test module. Unit tests - Individuals can supply hard-coded data row ids or global keys without configured a project must include a media type fixture to get the appropriate annotations. - - Each strategy provides a few items. - + Individuals can supply hard-coded data row ids or global keys without configured a project must include a media type fixture to get the appropriate annotations. + + Each strategy provides a few items. 
+ Labelbox Project (unit testing strategies do not make api calls so will have None for project) Data row identifiers (ids the annotation uses) Ontology: normalized ontology @@ -890,23 +918,23 @@ def prediction_id_mapping(request, normalized_ontology_by_media_type): if "configured_project" in request.fixturenames: project = request.getfixturevalue("configured_project") - data_row_identifiers = [{ - "id": data_row_id - } for data_row_id in project.data_row_ids] + data_row_identifiers = [ + {"id": data_row_id} for data_row_id in project.data_row_ids + ] ontology = project.ontology().normalized elif "configured_project_by_global_key" in request.fixturenames: project = request.getfixturevalue("configured_project_by_global_key") - data_row_identifiers = [{ - "globalKey": global_key - } for global_key in project.global_keys] + data_row_identifiers = [ + {"globalKey": global_key} for global_key in project.global_keys + ] ontology = project.ontology().normalized elif "module_project" in request.fixturenames: project = request.getfixturevalue("module_project") - data_row_identifiers = [{ - "id": data_row_id - } for data_row_id in project.data_row_ids] + data_row_identifiers = [ + {"id": data_row_id} for data_row_id in project.data_row_ids + ] ontology = project.ontology().normalized elif "hardcoded_datarow_id" in request.fixturenames: @@ -915,9 +943,9 @@ def prediction_id_mapping(request, normalized_ontology_by_media_type): project = None media_type = request.getfixturevalue("media_type") ontology = normalized_ontology_by_media_type[media_type] - data_row_identifiers = [{ - "id": request.getfixturevalue("hardcoded_datarow_id")() - }] + data_row_identifiers = [ + {"id": request.getfixturevalue("hardcoded_datarow_id")()} + ] elif "hardcoded_global_key" in request.fixturenames: if "media_type" not in request.fixturenames: @@ -925,9 +953,9 @@ def prediction_id_mapping(request, normalized_ontology_by_media_type): project = None media_type = request.getfixturevalue("media_type") ontology = normalized_ontology_by_media_type[media_type] - data_row_identifiers = [{ - "globalKey": request.getfixturevalue("hardcoded_global_key")() - }] + data_row_identifiers = [ + {"globalKey": request.getfixturevalue("hardcoded_global_key")()} + ] # Used for tests that need access to every ontology else: @@ -939,21 +967,25 @@ def prediction_id_mapping(request, normalized_ontology_by_media_type): base_annotations = [] for data_row_identifier in data_row_identifiers: base_annotation = {} - for feature in (ontology["tools"] + ontology["classifications"]): + for feature in ontology["tools"] + ontology["classifications"]: if "tool" in feature: - feature_type = (feature["tool"] if feature["classifications"] - == [] else f"{feature['tool']}_nested" - ) # tool vs nested classification tool + feature_type = ( + feature["tool"] + if feature["classifications"] == [] + else f"{feature['tool']}_nested" + ) # tool vs nested classification tool else: - feature_type = (feature["type"] if "scope" not in feature else - f"{feature['type']}_{feature['scope']}" - ) # checklist vs indexed checklist + feature_type = ( + feature["type"] + if "scope" not in feature + else f"{feature['type']}_{feature['scope']}" + ) # checklist vs indexed checklist base_annotation[feature_type] = { "uuid": str(uuid.uuid4()), "name": feature["name"], "tool": feature, - "dataRow": data_row_identifier + "dataRow": data_row_identifier, } base_annotations.append(base_annotation) @@ -968,26 +1000,16 @@ def polygon_inference(prediction_id_mapping): if "polygon" not in 
feature: continue polygon = feature["polygon"].copy() - polygon.update({ - "polygon": [ - { - "x": 147.692, - "y": 118.154 - }, - { - "x": 142.769, - "y": 104.923 - }, - { - "x": 57.846, - "y": 118.769 - }, - { - "x": 28.308, - "y": 169.846 - }, - ] - }) + polygon.update( + { + "polygon": [ + {"x": 147.692, "y": 118.154}, + {"x": 142.769, "y": 104.923}, + {"x": 57.846, "y": 118.769}, + {"x": 28.308, "y": 169.846}, + ] + } + ) del polygon["tool"] polygons.append(polygon) return polygons @@ -1000,14 +1022,11 @@ def rectangle_inference(prediction_id_mapping): if "rectangle" not in feature: continue rectangle = feature["rectangle"].copy() - rectangle.update({ - "bbox": { - "top": 48, - "left": 58, - "height": 65, - "width": 12 - }, - }) + rectangle.update( + { + "bbox": {"top": 48, "left": 58, "height": 65, "width": 12}, + } + ) del rectangle["tool"] rectangles.append(rectangle) return rectangles @@ -1020,34 +1039,35 @@ def rectangle_inference_with_confidence(prediction_id_mapping): if "rectangle_nested" not in feature: continue rectangle = feature["rectangle_nested"].copy() - rectangle.update({ - "bbox": { - "top": 48, - "left": 58, - "height": 65, - "width": 12 - }, - "classifications": [{ - "name": rectangle["tool"]["classifications"][0]["name"], - "answer": { - "name": - rectangle["tool"]["classifications"][0]["options"][0] - ["value"], - "classifications": [{ - "name": - rectangle["tool"]["classifications"][0]["options"] - [0]["options"][1]["name"], - "answer": - "nested answer", - }], - }, - }], - }) + rectangle.update( + { + "bbox": {"top": 48, "left": 58, "height": 65, "width": 12}, + "classifications": [ + { + "name": rectangle["tool"]["classifications"][0]["name"], + "answer": { + "name": rectangle["tool"]["classifications"][0][ + "options" + ][0]["value"], + "classifications": [ + { + "name": rectangle["tool"][ + "classifications" + ][0]["options"][0]["options"][1]["name"], + "answer": "nested answer", + } + ], + }, + } + ], + } + ) rectangle.update({"confidence": 0.9}) rectangle["classifications"][0]["answer"]["confidence"] = 0.8 rectangle["classifications"][0]["answer"]["classifications"][0][ - "confidence"] = 0.7 + "confidence" + ] = 0.7 del rectangle["tool"] rectangles.append(rectangle) @@ -1071,15 +1091,14 @@ def line_inference(prediction_id_mapping): if "line" not in feature: continue line = feature["line"].copy() - line.update({ - "line": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 150.692, - "y": 160.154 - }] - }) + line.update( + { + "line": [ + {"x": 147.692, "y": 118.154}, + {"x": 150.692, "y": 160.154}, + ] + } + ) del line["tool"] lines.append(line) return lines @@ -1093,24 +1112,20 @@ def line_inference_v2(prediction_id_mapping): continue line = feature["line"].copy() line_data = { - "groupKey": - "axial", - "segments": [{ - "keyframes": [{ - "frame": - 1, - "line": [ - { - "x": 147.692, - "y": 118.154 - }, + "groupKey": "axial", + "segments": [ + { + "keyframes": [ { - "x": 150.692, - "y": 160.154 - }, - ], - }] - },], + "frame": 1, + "line": [ + {"x": 147.692, "y": 118.154}, + {"x": 150.692, "y": 160.154}, + ], + } + ] + }, + ], } line.update(line_data) del line["tool"] @@ -1151,13 +1166,12 @@ def entity_inference_index(prediction_id_mapping): if "named-entity" not in feature: continue entity = feature["named-entity"].copy() - entity.update({ - "location": { - "start": 0, - "end": 8 - }, - "messageId": "0", - }) + entity.update( + { + "location": {"start": 0, "end": 8}, + "messageId": "0", + } + ) del entity["tool"] named_entities.append(entity) return 
named_entities @@ -1171,20 +1185,22 @@ def entity_inference_document(prediction_id_mapping): continue entity = feature["named-entity"].copy() document_selections = { - "textSelections": [{ - "tokenIds": [ - "3f984bf3-1d61-44f5-b59a-9658a2e3440f", - "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", - "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", - "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", - "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", - "67c7c19e-4654-425d-bf17-2adb8cf02c30", - "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", - "b0e94071-2187-461e-8e76-96c58738a52c", - ], - "groupId": "2f4336f4-a07e-4e0a-a9e1-5629b03b719b", - "page": 1, - }] + "textSelections": [ + { + "tokenIds": [ + "3f984bf3-1d61-44f5-b59a-9658a2e3440f", + "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", + "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", + "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", + "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", + "67c7c19e-4654-425d-bf17-2adb8cf02c30", + "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", + "b0e94071-2187-461e-8e76-96c58738a52c", + ], + "groupId": "2f4336f4-a07e-4e0a-a9e1-5629b03b719b", + "page": 1, + } + ] } entity.update(document_selections) del entity["tool"] @@ -1199,13 +1215,14 @@ def segmentation_inference(prediction_id_mapping): if "superpixel" not in feature: continue segmentation = feature["superpixel"].copy() - segmentation.update({ - "mask": { - "instanceURI": - "https://storage.googleapis.com/labelbox-datasets/image_sample_data/raster_seg.png", - "colorRGB": (255, 255, 255), + segmentation.update( + { + "mask": { + "instanceURI": "https://storage.googleapis.com/labelbox-datasets/image_sample_data/raster_seg.png", + "colorRGB": (255, 255, 255), + } } - }) + ) del segmentation["tool"] superpixel_masks.append(segmentation) return superpixel_masks @@ -1218,13 +1235,12 @@ def segmentation_inference_rle(prediction_id_mapping): if "superpixel" not in feature: continue segmentation = feature["superpixel"].copy() - segmentation.update({ - "uuid": str(uuid.uuid4()), - "mask": { - "size": [10, 10], - "counts": [1, 0, 10, 100] - }, - }) + segmentation.update( + { + "uuid": str(uuid.uuid4()), + "mask": {"size": [10, 10], "counts": [1, 0, 10, 100]}, + } + ) del segmentation["tool"] superpixel_masks.append(segmentation) return superpixel_masks @@ -1237,12 +1253,14 @@ def segmentation_inference_png(prediction_id_mapping): if "superpixel" not in feature: continue segmentation = feature["superpixel"].copy() - segmentation.update({ - "uuid": str(uuid.uuid4()), - "mask": { - "png": "somedata", - }, - }) + segmentation.update( + { + "uuid": str(uuid.uuid4()), + "mask": { + "png": "somedata", + }, + } + ) del segmentation["tool"] superpixel_masks.append(segmentation) return superpixel_masks @@ -1255,13 +1273,14 @@ def checklist_inference(prediction_id_mapping): if "checklist" not in feature: continue checklist = feature["checklist"].copy() - checklist.update({ - "answers": [{ - "name": "first_checklist_answer" - }, { - "name": "second_checklist_answer" - }] - }) + checklist.update( + { + "answers": [ + {"name": "first_checklist_answer"}, + {"name": "second_checklist_answer"}, + ] + } + ) del checklist["tool"] checklists.append(checklist) return checklists @@ -1274,14 +1293,15 @@ def checklist_inference_index(prediction_id_mapping): if "checklist_index" not in feature: return None checklist = feature["checklist_index"].copy() - checklist.update({ - "answers": [{ - "name": "first_checklist_answer" - }, { - "name": "second_checklist_answer" - }], - "messageId": "0", - }) + checklist.update( + { + "answers": [ + {"name": "first_checklist_answer"}, + 
{"name": "second_checklist_answer"}, + ], + "messageId": "0", + } + ) del checklist["tool"] checklists.append(checklist) return checklists @@ -1307,11 +1327,11 @@ def radio_response_inference(prediction_id_mapping): if "response-radio" not in feature: continue response_radio = feature["response-radio"].copy() - response_radio.update({ - "answer": { - "name": "first_radio_answer" - }, - }) + response_radio.update( + { + "answer": {"name": "first_radio_answer"}, + } + ) del response_radio["tool"] response_radios.append(response_radio) return response_radios @@ -1324,13 +1344,14 @@ def checklist_response_inference(prediction_id_mapping): if "response-checklist" not in feature: continue response_checklist = feature["response-checklist"].copy() - response_checklist.update({ - "answer": [{ - "name": "first_checklist_answer" - }, { - "name": "second_checklist_answer" - }] - }) + response_checklist.update( + { + "answer": [ + {"name": "first_checklist_answer"}, + {"name": "second_checklist_answer"}, + ] + } + ) del response_checklist["tool"] response_checklists.append(response_checklist) return response_checklists @@ -1392,25 +1413,29 @@ def video_checklist_inference(prediction_id_mapping): if "checklist" not in feature: continue checklist = feature["checklist"].copy() - checklist.update({ - "answers": [{ - "name": "first_checklist_answer" - }, { - "name": "second_checklist_answer" - }] - }) + checklist.update( + { + "answers": [ + {"name": "first_checklist_answer"}, + {"name": "second_checklist_answer"}, + ] + } + ) checklist.update( - {"frames": [ - { - "start": 7, - "end": 13, - }, - { - "start": 18, - "end": 19, - }, - ]}) + { + "frames": [ + { + "start": 7, + "end": 13, + }, + { + "start": 18, + "end": 19, + }, + ] + } + ) del checklist["tool"] checklists.append(checklist) return checklists @@ -1418,13 +1443,24 @@ def video_checklist_inference(prediction_id_mapping): @pytest.fixture def annotations_by_media_type( - polygon_inference, rectangle_inference, rectangle_inference_document, - line_inference_v2, line_inference, entity_inference, - entity_inference_index, entity_inference_document, - checklist_inference_index, text_inference_index, checklist_inference, - text_inference, video_checklist_inference, prompt_text_inference, - checklist_response_inference, radio_response_inference, - text_response_inference): + polygon_inference, + rectangle_inference, + rectangle_inference_document, + line_inference_v2, + line_inference, + entity_inference, + entity_inference_index, + entity_inference_document, + checklist_inference_index, + text_inference_index, + checklist_inference, + text_inference, + video_checklist_inference, + prompt_text_inference, + checklist_response_inference, + radio_response_inference, + text_response_inference, +): return { MediaType.Audio: [checklist_inference, text_inference], MediaType.Conversational: [ @@ -1450,22 +1486,26 @@ def annotations_by_media_type( MediaType.Text: [checklist_inference, text_inference, entity_inference], MediaType.Video: [video_checklist_inference], MediaType.LLMPromptResponseCreation: [ - prompt_text_inference, text_response_inference, - checklist_response_inference, radio_response_inference + prompt_text_inference, + text_response_inference, + checklist_response_inference, + radio_response_inference, ], MediaType.LLMPromptCreation: [prompt_text_inference], OntologyKind.ResponseCreation: [ - text_response_inference, checklist_response_inference, - radio_response_inference - ] + text_response_inference, + checklist_response_inference, + 
radio_response_inference, + ], } @pytest.fixture -def model_run_predictions(polygon_inference, rectangle_inference, - line_inference): +def model_run_predictions( + polygon_inference, rectangle_inference, line_inference +): # Not supporting mask since there isn't a signed url representing a seg mask to upload - return (polygon_inference + rectangle_inference + line_inference) + return polygon_inference + rectangle_inference + line_inference @pytest.fixture @@ -1476,17 +1516,28 @@ def object_predictions( entity_inference, segmentation_inference, ): - return (polygon_inference + rectangle_inference + line_inference + - entity_inference + segmentation_inference) + return ( + polygon_inference + + rectangle_inference + + line_inference + + entity_inference + + segmentation_inference + ) @pytest.fixture -def object_predictions_for_annotation_import(polygon_inference, - rectangle_inference, - line_inference, - segmentation_inference): - return (polygon_inference + rectangle_inference + line_inference + - segmentation_inference) +def object_predictions_for_annotation_import( + polygon_inference, + rectangle_inference, + line_inference, + segmentation_inference, +): + return ( + polygon_inference + + rectangle_inference + + line_inference + + segmentation_inference + ) @pytest.fixture @@ -1561,8 +1612,9 @@ def model_run_with_data_rows( model_run_predictions, ) upload_task.wait_until_done() - assert (upload_task.state == AnnotationImportState.FINISHED - ), "Label Import did not finish" + assert ( + upload_task.state == AnnotationImportState.FINISHED + ), "Label Import did not finish" assert ( len(upload_task.errors) == 0 ), f"Label Import {upload_task.name} failed with errors {upload_task.errors}" @@ -1574,12 +1626,16 @@ def model_run_with_data_rows( @pytest.fixture -def model_run_with_all_project_labels(client, configured_project, - model_run_predictions, - model_run: ModelRun, - wait_for_label_processing): +def model_run_with_all_project_labels( + client, + configured_project, + model_run_predictions, + model_run: ModelRun, + wait_for_label_processing, +): use_data_row_ids = list( - set([p["dataRow"]["id"] for p in model_run_predictions])) + set([p["dataRow"]["id"] for p in model_run_predictions]) + ) model_run.upsert_data_rows(use_data_row_ids) @@ -1590,8 +1646,9 @@ def model_run_with_all_project_labels(client, configured_project, model_run_predictions, ) upload_task.wait_until_done() - assert (upload_task.state == AnnotationImportState.FINISHED - ), "Label Import did not finish" + assert ( + upload_task.state == AnnotationImportState.FINISHED + ), "Label Import did not finish" assert ( len(upload_task.errors) == 0 ), f"Label Import {upload_task.name} failed with errors {upload_task.errors}" @@ -1603,7 +1660,6 @@ def model_run_with_all_project_labels(client, configured_project, class AnnotationImportTestHelpers: - @classmethod def assert_file_content(cls, url: str, predictions): response = requests.get(url) @@ -1644,34 +1700,16 @@ def expected_export_v2_image(): exported_annotations = { "objects": [ { - "name": - "polygon", - "value": - "polygon", - "annotation_kind": - "ImagePolygon", + "name": "polygon", + "value": "polygon", + "annotation_kind": "ImagePolygon", "classifications": [], "polygon": [ - { - "x": 147.692, - "y": 118.154 - }, - { - "x": 142.769, - "y": 104.923 - }, - { - "x": 57.846, - "y": 118.769 - }, - { - "x": 28.308, - "y": 169.846 - }, - { - "x": 147.692, - "y": 118.154 - }, + {"x": 147.692, "y": 118.154}, + {"x": 142.769, "y": 104.923}, + {"x": 57.846, "y": 118.769}, + {"x": 
28.308, "y": 169.846}, + {"x": 147.692, "y": 118.154}, ], }, { @@ -1687,44 +1725,37 @@ def expected_export_v2_image(): }, }, { - "name": - "polyline", - "value": - "polyline", - "annotation_kind": - "ImagePolyline", + "name": "polyline", + "value": "polyline", + "annotation_kind": "ImagePolyline", "classifications": [], - "line": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 150.692, - "y": 160.154 - }], + "line": [ + {"x": 147.692, "y": 118.154}, + {"x": 150.692, "y": 160.154}, + ], }, ], "classifications": [ { - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], + "name": "checklist", + "value": "checklist", + "checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], }, { "name": "text", "value": "text", - "text_answer": { - "content": "free form text..." - }, + "text_answer": {"content": "free form text..."}, }, ], "relationships": [], @@ -1738,30 +1769,29 @@ def expected_export_v2_audio(): expected_annotations = { "classifications": [ { - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], + "name": "checklist", + "value": "checklist", + "checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], }, { "name": "text", "value": "text", - "text_answer": { - "content": "free form text..." - }, + "text_answer": {"content": "free form text..."}, }, ], "segments": {}, - "timestamp": {} + "timestamp": {}, } return expected_annotations @@ -1774,24 +1804,23 @@ def expected_export_v2_html(): { "name": "text", "value": "text", - "text_answer": { - "content": "free form text..." 
- }, + "text_answer": {"content": "free form text..."}, }, { - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], + "name": "checklist", + "value": "checklist", + "checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], }, ], "relationships": [], @@ -1802,39 +1831,40 @@ def expected_export_v2_html(): @pytest.fixture() def expected_export_v2_text(): expected_annotations = { - "objects": [{ - "name": "named-entity", - "value": "named_entity", - "annotation_kind": "TextEntity", - "classifications": [], - 'location': { - 'start': 112, - 'end': 128, - 'token': "research suggests" - }, - }], + "objects": [ + { + "name": "named-entity", + "value": "named_entity", + "annotation_kind": "TextEntity", + "classifications": [], + "location": { + "start": 112, + "end": 128, + "token": "research suggests", + }, + } + ], "classifications": [ { - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], + "name": "checklist", + "value": "checklist", + "checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], }, { "name": "text", "value": "text", - "text_answer": { - "content": "free form text..." 
- }, + "text_answer": {"content": "free form text..."}, }, ], "relationships": [], @@ -1846,25 +1876,26 @@ def expected_export_v2_text(): def expected_export_v2_video(): expected_annotations = { "frames": {}, - "segments": { - "": [[7, 13], [18, 19]] - }, + "segments": {"": [[7, 13], [18, 19]]}, "key_frame_feature_map": {}, - "classifications": [{ - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], - }], + "classifications": [ + { + "name": "checklist", + "value": "checklist", + "checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], + } + ], } return expected_annotations @@ -1872,44 +1903,41 @@ def expected_export_v2_video(): @pytest.fixture() def expected_export_v2_conversation(): expected_annotations = { - "objects": [{ - "name": "named-entity", - "value": "named_entity", - "annotation_kind": "ConversationalTextEntity", - "classifications": [], - "conversational_location": { - "message_id": "0", - "location": { - "start": 0, - "end": 8 + "objects": [ + { + "name": "named-entity", + "value": "named_entity", + "annotation_kind": "ConversationalTextEntity", + "classifications": [], + "conversational_location": { + "message_id": "0", + "location": {"start": 0, "end": 8}, }, - }, - }], + } + ], "classifications": [ { - "name": - "checklist_index", - "value": - "checklist_index", - "message_id": - "0", - "conversational_checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], + "name": "checklist_index", + "value": "checklist_index", + "message_id": "0", + "conversational_checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], }, { "name": "text_index", "value": "text_index", "message_id": "0", - "conversational_text_answer": { - "content": "free form text..." 
- }, + "conversational_text_answer": {"content": "free form text..."}, }, ], "relationships": [], @@ -1928,22 +1956,13 @@ def expected_export_v2_dicom(): "1": { "objects": { "": { - "name": - "polyline", - "value": - "polyline", - "annotation_kind": - "DICOMPolyline", + "name": "polyline", + "value": "polyline", + "annotation_kind": "DICOMPolyline", "classifications": [], "line": [ - { - "x": 147.692, - "y": 118.154 - }, - { - "x": 150.692, - "y": 160.154 - }, + {"x": 147.692, "y": 118.154}, + {"x": 150.692, "y": 160.154}, ], } }, @@ -1954,30 +1973,18 @@ def expected_export_v2_dicom(): "Sagittal": { "name": "Sagittal", "classifications": [], - "frames": {} - }, - "Coronal": { - "name": "Coronal", - "classifications": [], - "frames": {} + "frames": {}, }, + "Coronal": {"name": "Coronal", "classifications": [], "frames": {}}, }, "segments": { - "Axial": { - "": [[1, 1]] - }, + "Axial": {"": [[1, 1]]}, "Sagittal": {}, - "Coronal": {} + "Coronal": {}, }, "classifications": [], "key_frame_feature_map": { - "": { - "Axial": { - "1": True - }, - "Coronal": {}, - "Sagittal": {} - } + "": {"Axial": {"1": True}, "Coronal": {}, "Sagittal": {}} }, } return expected_annotations @@ -1993,24 +2000,23 @@ def expected_export_v2_document(): "annotation_kind": "DocumentEntityToken", "classifications": [], "location": { - "groups": [{ - "id": - "2f4336f4-a07e-4e0a-a9e1-5629b03b719b", - "page_number": - 1, - "tokens": [ - "3f984bf3-1d61-44f5-b59a-9658a2e3440f", - "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", - "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", - "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", - "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", - "67c7c19e-4654-425d-bf17-2adb8cf02c30", - "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", - "b0e94071-2187-461e-8e76-96c58738a52c", - ], - "text": - "Metal-insulator (MI) transitions have been one of the", - }] + "groups": [ + { + "id": "2f4336f4-a07e-4e0a-a9e1-5629b03b719b", + "page_number": 1, + "tokens": [ + "3f984bf3-1d61-44f5-b59a-9658a2e3440f", + "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", + "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", + "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", + "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", + "67c7c19e-4654-425d-bf17-2adb8cf02c30", + "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", + "b0e94071-2187-461e-8e76-96c58738a52c", + ], + "text": "Metal-insulator (MI) transitions have been one of the", + } + ] }, }, { @@ -2029,26 +2035,25 @@ def expected_export_v2_document(): ], "classifications": [ { - "name": - "checklist", - "value": - "checklist", - "checklist_answers": [{ - "name": "first_checklist_answer", - "value": "first_checklist_answer", - "classifications": [] - }, { - "name": "second_checklist_answer", - "value": "second_checklist_answer", - "classifications": [] - }], + "name": "checklist", + "value": "checklist", + "checklist_answers": [ + { + "name": "first_checklist_answer", + "value": "first_checklist_answer", + "classifications": [], + }, + { + "name": "second_checklist_answer", + "value": "second_checklist_answer", + "classifications": [], + }, + ], }, { "name": "text", "value": "text", - "text_answer": { - "content": "free form text..." - }, + "text_answer": {"content": "free form text..."}, }, ], "relationships": [], @@ -2064,39 +2069,38 @@ def expected_export_v2_llm_prompt_response_creation(): { "name": "prompt-text", "value": "prompt-text", - "text_answer": { - "content": "free form text..." - }, + "text_answer": {"content": "free form text..."}, }, { - 'name': 'response-text', - 'text_answer': { - 'content': 'free form text...' 
- }, - 'value': 'response-text' + "name": "response-text", + "text_answer": {"content": "free form text..."}, + "value": "response-text", }, { - 'checklist_answers': [{ - 'classifications': [], - 'name': 'first_checklist_answer', - 'value': 'first_checklist_answer' - }, { - 'classifications': [], - 'name': 'second_checklist_answer', - 'value': 'second_checklist_answer' - }], - 'name': 'checklist-response', - 'value': 'checklist-response' + "checklist_answers": [ + { + "classifications": [], + "name": "first_checklist_answer", + "value": "first_checklist_answer", + }, + { + "classifications": [], + "name": "second_checklist_answer", + "value": "second_checklist_answer", + }, + ], + "name": "checklist-response", + "value": "checklist-response", }, { - 'name': 'radio-response', - 'radio_answer': { - 'classifications': [], - 'name': 'first_radio_answer', - 'value': 'first_radio_answer' + "name": "radio-response", + "radio_answer": { + "classifications": [], + "name": "first_radio_answer", + "value": "first_radio_answer", }, - 'name': 'radio-response', - 'value': 'radio-response' + "name": "radio-response", + "value": "radio-response", }, ], "relationships": [], @@ -2108,13 +2112,13 @@ def expected_export_v2_llm_prompt_response_creation(): def expected_export_v2_llm_prompt_creation(): expected_annotations = { "objects": [], - "classifications": [{ - "name": "prompt-text", - "value": "prompt-text", - "text_answer": { - "content": "free form text..." + "classifications": [ + { + "name": "prompt-text", + "value": "prompt-text", + "text_answer": {"content": "free form text..."}, }, - },], + ], "relationships": [], } return expected_annotations @@ -2123,38 +2127,39 @@ def expected_export_v2_llm_prompt_creation(): @pytest.fixture() def expected_export_v2_llm_response_creation(): expected_annotations = { - 'objects': [], - 'relationships': [], + "objects": [], + "relationships": [], "classifications": [ { - 'name': 'response-text', - 'text_answer': { - 'content': 'free form text...' 
- }, - 'value': 'response-text' + "name": "response-text", + "text_answer": {"content": "free form text..."}, + "value": "response-text", }, { - 'checklist_answers': [{ - 'classifications': [], - 'name': 'first_checklist_answer', - 'value': 'first_checklist_answer' - }, { - 'classifications': [], - 'name': 'second_checklist_answer', - 'value': 'second_checklist_answer' - }], - 'name': 'checklist-response', - 'value': 'checklist-response' + "checklist_answers": [ + { + "classifications": [], + "name": "first_checklist_answer", + "value": "first_checklist_answer", + }, + { + "classifications": [], + "name": "second_checklist_answer", + "value": "second_checklist_answer", + }, + ], + "name": "checklist-response", + "value": "checklist-response", }, { - 'name': 'radio-response', - 'radio_answer': { - 'classifications': [], - 'name': 'first_radio_answer', - 'value': 'first_radio_answer' + "name": "radio-response", + "radio_answer": { + "classifications": [], + "name": "first_radio_answer", + "value": "first_radio_answer", }, - 'name': 'radio-response', - 'value': 'radio-response' + "name": "radio-response", + "value": "radio-response", }, ], } @@ -2162,43 +2167,35 @@ def expected_export_v2_llm_response_creation(): @pytest.fixture -def exports_v2_by_media_type(expected_export_v2_image, expected_export_v2_audio, - expected_export_v2_html, expected_export_v2_text, - expected_export_v2_video, - expected_export_v2_conversation, - expected_export_v2_dicom, - expected_export_v2_document, - expected_export_v2_llm_prompt_response_creation, - expected_export_v2_llm_prompt_creation, - expected_export_v2_llm_response_creation): +def exports_v2_by_media_type( + expected_export_v2_image, + expected_export_v2_audio, + expected_export_v2_html, + expected_export_v2_text, + expected_export_v2_video, + expected_export_v2_conversation, + expected_export_v2_dicom, + expected_export_v2_document, + expected_export_v2_llm_prompt_response_creation, + expected_export_v2_llm_prompt_creation, + expected_export_v2_llm_response_creation, +): return { - MediaType.Image: - expected_export_v2_image, - MediaType.Audio: - expected_export_v2_audio, - MediaType.Html: - expected_export_v2_html, - MediaType.Text: - expected_export_v2_text, - MediaType.Video: - expected_export_v2_video, - MediaType.Conversational: - expected_export_v2_conversation, - MediaType.Dicom: - expected_export_v2_dicom, - MediaType.Document: - expected_export_v2_document, - MediaType.LLMPromptResponseCreation: - expected_export_v2_llm_prompt_response_creation, - MediaType.LLMPromptCreation: - expected_export_v2_llm_prompt_creation, - OntologyKind.ResponseCreation: - expected_export_v2_llm_response_creation + MediaType.Image: expected_export_v2_image, + MediaType.Audio: expected_export_v2_audio, + MediaType.Html: expected_export_v2_html, + MediaType.Text: expected_export_v2_text, + MediaType.Video: expected_export_v2_video, + MediaType.Conversational: expected_export_v2_conversation, + MediaType.Dicom: expected_export_v2_dicom, + MediaType.Document: expected_export_v2_document, + MediaType.LLMPromptResponseCreation: expected_export_v2_llm_prompt_response_creation, + MediaType.LLMPromptCreation: expected_export_v2_llm_prompt_creation, + OntologyKind.ResponseCreation: expected_export_v2_llm_response_creation, } class Helpers: - @staticmethod def remove_keys_recursive(d, keys): for k in keys: @@ -2230,7 +2227,6 @@ def rename_cuid_key_recursive(d): @staticmethod def set_project_media_type_from_data_type(project, data_type_class): - def to_pascal_case(name: str) -> 
str: return "".join([word.capitalize() for word in name.split("_")]) @@ -2250,7 +2246,7 @@ def to_pascal_case(name: str) -> str: @staticmethod def find_data_row_filter(data_row): - return lambda dr: dr['data_row']['id'] == data_row.uid + return lambda dr: dr["data_row"]["id"] == data_row.uid @pytest.fixture diff --git a/libs/labelbox/tests/data/annotation_import/test_annotation_import_limit.py b/libs/labelbox/tests/data/annotation_import/test_annotation_import_limit.py index 297f45c52..dec20fbb5 100644 --- a/libs/labelbox/tests/data/annotation_import/test_annotation_import_limit.py +++ b/libs/labelbox/tests/data/annotation_import/test_annotation_import_limit.py @@ -1,33 +1,56 @@ import itertools import uuid -from labelbox.schema.annotation_import import AnnotationImport, MALPredictionImport +from labelbox.schema.annotation_import import ( + AnnotationImport, + MALPredictionImport, +) from labelbox.schema.media_type import MediaType import pytest from unittest.mock import patch -@patch('labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT', 1) -def test_above_annotation_limit_on_single_import_on_single_data_row(annotations_by_media_type): - - annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[MediaType.Image])) +@patch("labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT", 1) +def test_above_annotation_limit_on_single_import_on_single_data_row( + annotations_by_media_type, +): + annotations_ndjson = list( + itertools.chain.from_iterable( + annotations_by_media_type[MediaType.Image] + ) + ) data_row_id = annotations_ndjson[0]["dataRow"]["id"] - data_row_annotations = [annotation for annotation in annotations_ndjson if annotation["dataRow"]["id"] == data_row_id and "bbox" in annotation] - - with pytest.raises(ValueError): - AnnotationImport._validate_data_rows([data_row_annotations[0]]*2) + data_row_annotations = [ + annotation + for annotation in annotations_ndjson + if annotation["dataRow"]["id"] == data_row_id and "bbox" in annotation + ] + with pytest.raises(ValueError): + AnnotationImport._validate_data_rows([data_row_annotations[0]] * 2) -@patch('labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT', 1) -def test_above_annotation_limit_divided_among_different_rows(annotations_by_media_type): - annotations_ndjson = list(itertools.chain.from_iterable(annotations_by_media_type[MediaType.Image])) +@patch("labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT", 1) +def test_above_annotation_limit_divided_among_different_rows( + annotations_by_media_type, +): + annotations_ndjson = list( + itertools.chain.from_iterable( + annotations_by_media_type[MediaType.Image] + ) + ) data_row_id = annotations_ndjson[0]["dataRow"]["id"] - - first_data_row_annotation = [annotation for annotation in annotations_ndjson if annotation["dataRow"]["id"] == data_row_id and "bbox" in annotation][0] - + + first_data_row_annotation = [ + annotation + for annotation in annotations_ndjson + if annotation["dataRow"]["id"] == data_row_id and "bbox" in annotation + ][0] + second_data_row_annotation = first_data_row_annotation.copy() second_data_row_annotation["dataRow"]["id"] == "data_row_id_2" - + with pytest.raises(ValueError): - AnnotationImport._validate_data_rows([first_data_row_annotation, second_data_row_annotation]*2) \ No newline at end of file + AnnotationImport._validate_data_rows( + [first_data_row_annotation, second_data_row_annotation] * 2 + ) diff --git a/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py 
b/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py index 9e9abd47f..9abae1422 100644 --- a/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py +++ b/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py @@ -1,17 +1,30 @@ from unittest.mock import patch import uuid from labelbox import parser, Project -from labelbox.data.annotation_types.data.generic_data_row_data import GenericDataRowData +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) import pytest import random from labelbox.data.annotation_types.annotation import ObjectAnnotation -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnnotation, ClassificationAnswer, Radio +from labelbox.data.annotation_types.classification.classification import ( + Checklist, + ClassificationAnnotation, + ClassificationAnswer, + Radio, +) from labelbox.data.annotation_types.data.video import VideoData from labelbox.data.annotation_types.geometry.point import Point -from labelbox.data.annotation_types.geometry.rectangle import Rectangle, RectangleUnit +from labelbox.data.annotation_types.geometry.rectangle import ( + Rectangle, + RectangleUnit, +) from labelbox.data.annotation_types.label import Label from labelbox.data.annotation_types.data.text import TextData -from labelbox.data.annotation_types.ner import DocumentEntity, DocumentTextSelection +from labelbox.data.annotation_types.ner import ( + DocumentEntity, + DocumentTextSelection, +) from labelbox.data.annotation_types.video import VideoObjectAnnotation from labelbox.data.serialization import NDJsonConverter @@ -20,20 +33,22 @@ from labelbox.schema.enums import BulkImportRequestState from labelbox.schema.annotation_import import LabelImport, MALPredictionImport from labelbox.schema.media_type import MediaType + """ - Here we only want to check that the uploads are calling the validation - Then with unit tests we can check the types of errors raised """ -#TODO: remove library once bulk import requests are removed +# TODO: remove library once bulk import requests are removed + @pytest.mark.order(1) def test_create_from_url(module_project): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - bulk_import_request = module_project.upload_annotations(name=name, - annotations=url, - validate=False) + bulk_import_request = module_project.upload_annotations( + name=name, annotations=url, validate=False + ) assert bulk_import_request.project() == module_project assert bulk_import_request.name == name @@ -47,18 +62,20 @@ def test_validate_file(module_project): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" with pytest.raises(MALValidationError): - module_project.upload_annotations(name=name, - annotations=url, - validate=True) - #Schema ids shouldn't match + module_project.upload_annotations( + name=name, annotations=url, validate=True + ) + # Schema ids shouldn't match -def test_create_from_objects(module_project: Project, predictions, - annotation_import_test_helpers): +def test_create_from_objects( + module_project: Project, predictions, annotation_import_test_helpers +): name = str(uuid.uuid4()) bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions) + name=name, annotations=predictions + ) assert bulk_import_request.project() == module_project assert bulk_import_request.name == name @@ 
-66,16 +83,19 @@ def test_create_from_objects(module_project: Project, predictions, assert bulk_import_request.status_file_url is None assert bulk_import_request.state == BulkImportRequestState.RUNNING annotation_import_test_helpers.assert_file_content( - bulk_import_request.input_file_url, predictions) + bulk_import_request.input_file_url, predictions + ) -def test_create_from_label_objects(module_project, predictions, - annotation_import_test_helpers): +def test_create_from_label_objects( + module_project, predictions, annotation_import_test_helpers +): name = str(uuid.uuid4()) labels = list(NDJsonConverter.deserialize(predictions)) bulk_import_request = module_project.upload_annotations( - name=name, annotations=labels) + name=name, annotations=labels + ) assert bulk_import_request.project() == module_project assert bulk_import_request.name == name @@ -84,11 +104,13 @@ def test_create_from_label_objects(module_project, predictions, assert bulk_import_request.state == BulkImportRequestState.RUNNING normalized_predictions = list(NDJsonConverter.serialize(labels)) annotation_import_test_helpers.assert_file_content( - bulk_import_request.input_file_url, normalized_predictions) + bulk_import_request.input_file_url, normalized_predictions + ) -def test_create_from_local_file(tmp_path, predictions, module_project, - annotation_import_test_helpers): +def test_create_from_local_file( + tmp_path, predictions, module_project, annotation_import_test_helpers +): name = str(uuid.uuid4()) file_name = f"{name}.ndjson" file_path = tmp_path / file_name @@ -96,7 +118,8 @@ def test_create_from_local_file(tmp_path, predictions, module_project, parser.dump(predictions, f) bulk_import_request = module_project.upload_annotations( - name=name, annotations=str(file_path), validate=False) + name=name, annotations=str(file_path), validate=False + ) assert bulk_import_request.project() == module_project assert bulk_import_request.name == name @@ -104,18 +127,20 @@ def test_create_from_local_file(tmp_path, predictions, module_project, assert bulk_import_request.status_file_url is None assert bulk_import_request.state == BulkImportRequestState.RUNNING annotation_import_test_helpers.assert_file_content( - bulk_import_request.input_file_url, predictions) + bulk_import_request.input_file_url, predictions + ) def test_get(client, module_project): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - module_project.upload_annotations(name=name, - annotations=url, - validate=False) + module_project.upload_annotations( + name=name, annotations=url, validate=False + ) bulk_import_request = BulkImportRequest.from_name( - client, project_id=module_project.uid, name=name) + client, project_id=module_project.uid, name=name + ) assert bulk_import_request.project() == module_project assert bulk_import_request.name == name @@ -133,7 +158,8 @@ def test_validate_ndjson(tmp_path, module_project): with pytest.raises(ValueError): module_project.upload_annotations( - name="name", validate=True, annotations=str(file_path)) + name="name", validate=True, annotations=str(file_path) + ) def test_validate_ndjson_uuid(tmp_path, module_project, predictions): @@ -141,31 +167,34 @@ def test_validate_ndjson_uuid(tmp_path, module_project, predictions): file_path = tmp_path / file_name repeat_uuid = predictions.copy() uid = str(uuid.uuid4()) - repeat_uuid[0]['uuid'] = uid - repeat_uuid[1]['uuid'] = uid + repeat_uuid[0]["uuid"] = uid + repeat_uuid[1]["uuid"] = uid with file_path.open("w") 
as f: parser.dump(repeat_uuid, f) with pytest.raises(UuidError): - module_project.upload_annotations(name="name", - validate=True, - annotations=str(file_path)) + module_project.upload_annotations( + name="name", validate=True, annotations=str(file_path) + ) with pytest.raises(UuidError): - module_project.upload_annotations(name="name", - validate=True, - annotations=repeat_uuid) + module_project.upload_annotations( + name="name", validate=True, annotations=repeat_uuid + ) -@pytest.mark.skip("Slow test and uses a deprecated api endpoint for annotation imports") -def test_wait_till_done(rectangle_inference, - project): +@pytest.mark.skip( + "Slow test and uses a deprecated api endpoint for annotation imports" +) +def test_wait_till_done(rectangle_inference, project): name = str(uuid.uuid4()) url = project.client.upload_data( - content=parser.dumps(rectangle_inference), sign=True) + content=parser.dumps(rectangle_inference), sign=True + ) bulk_import_request = project.upload_annotations( - name=name, annotations=url, validate=False) + name=name, annotations=url, validate=False + ) assert len(bulk_import_request.inputs) == 1 bulk_import_request.wait_until_done() @@ -174,11 +203,12 @@ def test_wait_till_done(rectangle_inference, # Check that the status files are being returned as expected assert len(bulk_import_request.errors) == 0 assert len(bulk_import_request.inputs) == 1 - assert bulk_import_request.inputs[0]['uuid'] == rectangle_inference['uuid'] + assert bulk_import_request.inputs[0]["uuid"] == rectangle_inference["uuid"] assert len(bulk_import_request.statuses) == 1 - assert bulk_import_request.statuses[0]['status'] == 'SUCCESS' - assert bulk_import_request.statuses[0]['uuid'] == rectangle_inference[ - 'uuid'] + assert bulk_import_request.statuses[0]["status"] == "SUCCESS" + assert ( + bulk_import_request.statuses[0]["uuid"] == rectangle_inference["uuid"] + ) def test_project_bulk_import_requests(module_project, predictions): @@ -187,17 +217,20 @@ def test_project_bulk_import_requests(module_project, predictions): name = str(uuid.uuid4()) bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions) + name=name, annotations=predictions + ) bulk_import_request.wait_until_done() name = str(uuid.uuid4()) bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions) + name=name, annotations=predictions + ) bulk_import_request.wait_until_done() name = str(uuid.uuid4()) bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions) + name=name, annotations=predictions + ) bulk_import_request.wait_until_done() result = module_project.bulk_import_requests() @@ -206,12 +239,16 @@ def test_project_bulk_import_requests(module_project, predictions): def test_delete(module_project, predictions): name = str(uuid.uuid4()) - + bulk_import_requests = module_project.bulk_import_requests() - [bulk_import_request.delete() for bulk_import_request in bulk_import_requests] - + [ + bulk_import_request.delete() + for bulk_import_request in bulk_import_requests + ] + bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions) + name=name, annotations=predictions + ) bulk_import_request.wait_until_done() all_import_requests = module_project.bulk_import_requests() assert len(list(all_import_requests)) == 1 diff --git a/libs/labelbox/tests/data/annotation_import/test_data_types.py b/libs/labelbox/tests/data/annotation_import/test_data_types.py index d7b3ef825..1e45295ef 100644 --- 
a/libs/labelbox/tests/data/annotation_import/test_data_types.py +++ b/libs/labelbox/tests/data/annotation_import/test_data_types.py @@ -37,13 +37,15 @@ def test_data_row_type_by_data_row_id( annotations_by_media_type, hardcoded_datarow_id, ): - annotations_ndjson = annotations_by_media_type[media_type] + annotations_ndjson = annotations_by_media_type[media_type] annotations_ndjson = [annotation[0] for annotation in annotations_ndjson] - + label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] - - data_label = lb_types.Label(data=data_type_class(uid = hardcoded_datarow_id()), - annotations=label.annotations) + + data_label = lb_types.Label( + data=data_type_class(uid=hardcoded_datarow_id()), + annotations=label.annotations, + ) assert data_label.data.uid == label.data.uid assert label.annotations == data_label.annotations @@ -67,13 +69,15 @@ def test_data_row_type_by_global_key( annotations_by_media_type, hardcoded_global_key, ): - annotations_ndjson = annotations_by_media_type[media_type] + annotations_ndjson = annotations_by_media_type[media_type] annotations_ndjson = [annotation[0] for annotation in annotations_ndjson] - + label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] - - data_label = lb_types.Label(data=data_type_class(global_key = hardcoded_global_key()), - annotations=label.annotations) + + data_label = lb_types.Label( + data=data_type_class(global_key=hardcoded_global_key()), + annotations=label.annotations, + ) assert data_label.data.global_key == label.data.global_key assert label.annotations == data_label.annotations diff --git a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py index fa2c9e3f8..f8f0c449a 100644 --- a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py +++ b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py @@ -1,5 +1,7 @@ import datetime -from labelbox.data.annotation_types.data.generic_data_row_data import GenericDataRowData +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) from labelbox.data.serialization.ndjson.converter import NDJsonConverter from labelbox.data.annotation_types import Label import pytest @@ -10,6 +12,7 @@ from labelbox.schema.annotation_import import AnnotationImportState from labelbox import Project, Client, OntologyKind import itertools + """ - integration test for importing mal labels and ground truths with each supported MediaType. - NDJSON is used to generate annotations. 
@@ -18,7 +21,8 @@ def validate_iso_format(date_string: str): parsed_t = datetime.datetime.fromisoformat( - date_string) # this will blow up if the string is not in iso format + date_string + ) # this will blow up if the string is not in iso format assert parsed_t.hour is not None assert parsed_t.minute is not None assert parsed_t.second is not None @@ -26,16 +30,18 @@ def validate_iso_format(date_string: str): @pytest.mark.parametrize( "media_type, data_type_class", - [(MediaType.Audio, GenericDataRowData), - (MediaType.Html, GenericDataRowData), - (MediaType.Image, GenericDataRowData), - (MediaType.Text, GenericDataRowData), - (MediaType.Video, GenericDataRowData), - (MediaType.Conversational, GenericDataRowData), - (MediaType.Document, GenericDataRowData), - (MediaType.LLMPromptResponseCreation, GenericDataRowData), - (MediaType.LLMPromptCreation, GenericDataRowData), - (OntologyKind.ResponseCreation, GenericDataRowData)], + [ + (MediaType.Audio, GenericDataRowData), + (MediaType.Html, GenericDataRowData), + (MediaType.Image, GenericDataRowData), + (MediaType.Text, GenericDataRowData), + (MediaType.Video, GenericDataRowData), + (MediaType.Conversational, GenericDataRowData), + (MediaType.Document, GenericDataRowData), + (MediaType.LLMPromptResponseCreation, GenericDataRowData), + (MediaType.LLMPromptCreation, GenericDataRowData), + (OntologyKind.ResponseCreation, GenericDataRowData), + ], ) def test_generic_data_row_type_by_data_row_id( media_type, @@ -48,8 +54,10 @@ def test_generic_data_row_type_by_data_row_id( label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] - data_label = Label(data=data_type_class(uid=hardcoded_datarow_id()), - annotations=label.annotations) + data_label = Label( + data=data_type_class(uid=hardcoded_datarow_id()), + annotations=label.annotations, + ) assert data_label.data.uid == label.data.uid assert label.annotations == data_label.annotations @@ -67,7 +75,7 @@ def test_generic_data_row_type_by_data_row_id( (MediaType.Document, GenericDataRowData), # (MediaType.LLMPromptResponseCreation, GenericDataRowData), # (MediaType.LLMPromptCreation, GenericDataRowData), - (OntologyKind.ResponseCreation, GenericDataRowData) + (OntologyKind.ResponseCreation, GenericDataRowData), ], ) def test_generic_data_row_type_by_global_key( @@ -81,8 +89,10 @@ def test_generic_data_row_type_by_global_key( label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] - data_label = Label(data=data_type_class(global_key=hardcoded_global_key()), - annotations=label.annotations) + data_label = Label( + data=data_type_class(global_key=hardcoded_global_key()), + annotations=label.annotations, + ) assert data_label.data.global_key == label.data.global_key assert label.annotations == data_label.annotations @@ -90,16 +100,24 @@ def test_generic_data_row_type_by_global_key( @pytest.mark.parametrize( "configured_project, media_type", - [(MediaType.Audio, MediaType.Audio), (MediaType.Html, MediaType.Html), - (MediaType.Image, MediaType.Image), (MediaType.Text, MediaType.Text), - (MediaType.Video, MediaType.Video), - (MediaType.Conversational, MediaType.Conversational), - (MediaType.Document, MediaType.Document), - (MediaType.Dicom, MediaType.Dicom), - (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), - (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), - (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation)], - indirect=["configured_project"]) + [ + (MediaType.Audio, MediaType.Audio), + (MediaType.Html, MediaType.Html), + (MediaType.Image, 
MediaType.Image), + (MediaType.Text, MediaType.Text), + (MediaType.Video, MediaType.Video), + (MediaType.Conversational, MediaType.Conversational), + (MediaType.Document, MediaType.Document), + (MediaType.Dicom, MediaType.Dicom), + ( + MediaType.LLMPromptResponseCreation, + MediaType.LLMPromptResponseCreation, + ), + (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation), + ], + indirect=["configured_project"], +) def test_import_media_types( client: Client, configured_project: Project, @@ -110,18 +128,23 @@ def test_import_media_types( media_type, ): annotations_ndjson = list( - itertools.chain.from_iterable(annotations_by_media_type[media_type])) + itertools.chain.from_iterable(annotations_by_media_type[media_type]) + ) label_import = lb.LabelImport.create_from_objects( - client, configured_project.uid, f"test-import-{media_type}", - annotations_ndjson) + client, + configured_project.uid, + f"test-import-{media_type}", + annotations_ndjson, + ) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED assert len(label_import.errors) == 0 result = export_v2_test_helpers.run_project_export_v2_task( - configured_project) + configured_project + ) assert result @@ -129,20 +152,28 @@ def test_import_media_types( # timestamp fields are in iso format validate_iso_format(exported_data["data_row"]["details"]["created_at"]) validate_iso_format(exported_data["data_row"]["details"]["updated_at"]) - validate_iso_format(exported_data["projects"][configured_project.uid] - ["labels"][0]["label_details"]["created_at"]) - validate_iso_format(exported_data["projects"][configured_project.uid] - ["labels"][0]["label_details"]["updated_at"]) - - assert exported_data["data_row"][ - "id"] in configured_project.data_row_ids + validate_iso_format( + exported_data["projects"][configured_project.uid]["labels"][0][ + "label_details" + ]["created_at"] + ) + validate_iso_format( + exported_data["projects"][configured_project.uid]["labels"][0][ + "label_details" + ]["updated_at"] + ) + + assert ( + exported_data["data_row"]["id"] in configured_project.data_row_ids + ) exported_project = exported_data["projects"][configured_project.uid] exported_project_labels = exported_project["labels"][0] exported_annotations = exported_project_labels["annotations"] expected_data = exports_v2_by_media_type[media_type] - helpers.remove_keys_recursive(exported_annotations, - ["feature_id", "feature_schema_id"]) + helpers.remove_keys_recursive( + exported_annotations, ["feature_id", "feature_schema_id"] + ) helpers.rename_cuid_key_recursive(exported_annotations) assert exported_annotations == expected_data @@ -150,30 +181,46 @@ def test_import_media_types( @pytest.mark.parametrize( "configured_project_by_global_key, media_type", - [(MediaType.Audio, MediaType.Audio), (MediaType.Html, MediaType.Html), - (MediaType.Image, MediaType.Image), (MediaType.Text, MediaType.Text), - (MediaType.Video, MediaType.Video), - (MediaType.Conversational, MediaType.Conversational), - (MediaType.Document, MediaType.Document), - (MediaType.Dicom, MediaType.Dicom), - (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation)], - indirect=["configured_project_by_global_key"]) + [ + (MediaType.Audio, MediaType.Audio), + (MediaType.Html, MediaType.Html), + (MediaType.Image, MediaType.Image), + (MediaType.Text, MediaType.Text), + (MediaType.Video, MediaType.Video), + (MediaType.Conversational, MediaType.Conversational), + (MediaType.Document, 
MediaType.Document), + (MediaType.Dicom, MediaType.Dicom), + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation), + ], + indirect=["configured_project_by_global_key"], +) def test_import_media_types_by_global_key( - client, configured_project_by_global_key, annotations_by_media_type, - exports_v2_by_media_type, export_v2_test_helpers, helpers, media_type): + client, + configured_project_by_global_key, + annotations_by_media_type, + exports_v2_by_media_type, + export_v2_test_helpers, + helpers, + media_type, +): annotations_ndjson = list( - itertools.chain.from_iterable(annotations_by_media_type[media_type])) + itertools.chain.from_iterable(annotations_by_media_type[media_type]) + ) label_import = lb.LabelImport.create_from_objects( - client, configured_project_by_global_key.uid, - f"test-import-{media_type}", annotations_ndjson) + client, + configured_project_by_global_key.uid, + f"test-import-{media_type}", + annotations_ndjson, + ) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED assert len(label_import.errors) == 0 result = export_v2_test_helpers.run_project_export_v2_task( - configured_project_by_global_key) + configured_project_by_global_key + ) assert result @@ -182,22 +229,30 @@ def test_import_media_types_by_global_key( validate_iso_format(exported_data["data_row"]["details"]["created_at"]) validate_iso_format(exported_data["data_row"]["details"]["updated_at"]) validate_iso_format( - exported_data["projects"][configured_project_by_global_key.uid] - ["labels"][0]["label_details"]["created_at"]) + exported_data["projects"][configured_project_by_global_key.uid][ + "labels" + ][0]["label_details"]["created_at"] + ) validate_iso_format( - exported_data["projects"][configured_project_by_global_key.uid] - ["labels"][0]["label_details"]["updated_at"]) - - assert exported_data["data_row"][ - "id"] in configured_project_by_global_key.data_row_ids + exported_data["projects"][configured_project_by_global_key.uid][ + "labels" + ][0]["label_details"]["updated_at"] + ) + + assert ( + exported_data["data_row"]["id"] + in configured_project_by_global_key.data_row_ids + ) exported_project = exported_data["projects"][ - configured_project_by_global_key.uid] + configured_project_by_global_key.uid + ] exported_project_labels = exported_project["labels"][0] exported_annotations = exported_project_labels["annotations"] expected_data = exports_v2_by_media_type[media_type] - helpers.remove_keys_recursive(exported_annotations, - ["feature_id", "feature_schema_id"]) + helpers.remove_keys_recursive( + exported_annotations, ["feature_id", "feature_schema_id"] + ) helpers.rename_cuid_key_recursive(exported_annotations) assert exported_annotations == expected_data @@ -214,15 +269,21 @@ def test_import_media_types_by_global_key( (MediaType.Conversational, MediaType.Conversational), (MediaType.Document, MediaType.Document), (MediaType.Dicom, MediaType.Dicom), - (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), + ( + MediaType.LLMPromptResponseCreation, + MediaType.LLMPromptResponseCreation, + ), (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), - (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation) + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation), ], - indirect=["configured_project"]) -def test_import_mal_annotations(client, configured_project: Project, - annotations_by_media_type, media_type): + indirect=["configured_project"], +) +def test_import_mal_annotations( + client, configured_project: 
Project, annotations_by_media_type, media_type +): annotations_ndjson = list( - itertools.chain.from_iterable(annotations_by_media_type[media_type])) + itertools.chain.from_iterable(annotations_by_media_type[media_type]) + ) import_annotations = lb.MALPredictionImport.create_from_objects( client=client, @@ -238,20 +299,28 @@ def test_import_mal_annotations(client, configured_project: Project, @pytest.mark.parametrize( "configured_project_by_global_key, media_type", - [(MediaType.Audio, MediaType.Audio), (MediaType.Html, MediaType.Html), - (MediaType.Image, MediaType.Image), (MediaType.Text, MediaType.Text), - (MediaType.Video, MediaType.Video), - (MediaType.Conversational, MediaType.Conversational), - (MediaType.Document, MediaType.Document), - (MediaType.Dicom, MediaType.Dicom), - (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation)], - indirect=["configured_project_by_global_key"]) + [ + (MediaType.Audio, MediaType.Audio), + (MediaType.Html, MediaType.Html), + (MediaType.Image, MediaType.Image), + (MediaType.Text, MediaType.Text), + (MediaType.Video, MediaType.Video), + (MediaType.Conversational, MediaType.Conversational), + (MediaType.Document, MediaType.Document), + (MediaType.Dicom, MediaType.Dicom), + (OntologyKind.ResponseCreation, OntologyKind.ResponseCreation), + ], + indirect=["configured_project_by_global_key"], +) def test_import_mal_annotations_global_key( - client, configured_project_by_global_key: Project, - annotations_by_media_type, media_type): - + client, + configured_project_by_global_key: Project, + annotations_by_media_type, + media_type, +): annotations_ndjson = list( - itertools.chain.from_iterable(annotations_by_media_type[media_type])) + itertools.chain.from_iterable(annotations_by_media_type[media_type]) + ) import_annotations = lb.MALPredictionImport.create_from_objects( client=client, diff --git a/libs/labelbox/tests/data/annotation_import/test_label_import.py b/libs/labelbox/tests/data/annotation_import/test_label_import.py index 50b701813..5576025fd 100644 --- a/libs/labelbox/tests/data/annotation_import/test_label_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_label_import.py @@ -3,6 +3,7 @@ from labelbox import parser from labelbox.schema.annotation_import import AnnotationImportState, LabelImport + """ - Here we only want to check that the uploads are calling the validation - Then with unit tests we can check the types of errors raised @@ -10,50 +11,53 @@ """ -def test_create_with_url_arg(client, module_project, - annotation_import_test_helpers): +def test_create_with_url_arg( + client, module_project, annotation_import_test_helpers +): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" label_import = LabelImport.create( - client=client, - id=module_project.uid, - name=name, - url=url) + client=client, id=module_project.uid, name=name, url=url + ) assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name, url) -def test_create_from_url(client, module_project, - annotation_import_test_helpers): +def test_create_from_url( + client, module_project, annotation_import_test_helpers +): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" label_import = LabelImport.create_from_url( - client=client, - project_id=module_project.uid, - name=name, - url=url) + client=client, project_id=module_project.uid, name=name, url=url + ) assert 
label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name, url) -def test_create_with_labels_arg(client, module_project, object_predictions, - annotation_import_test_helpers): +def test_create_with_labels_arg( + client, module_project, object_predictions, annotation_import_test_helpers +): """this test should check running state only to validate running, not completed""" name = str(uuid.uuid4()) - label_import = LabelImport.create(client=client, - id=module_project.uid, - name=name, - labels=object_predictions) + label_import = LabelImport.create( + client=client, + id=module_project.uid, + name=name, + labels=object_predictions, + ) assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) + label_import.input_file_url, object_predictions + ) -def test_create_from_objects(client, module_project, object_predictions, - annotation_import_test_helpers): +def test_create_from_objects( + client, module_project, object_predictions, annotation_import_test_helpers +): """this test should check running state only to validate running, not completed""" name = str(uuid.uuid4()) @@ -61,16 +65,23 @@ def test_create_from_objects(client, module_project, object_predictions, client=client, project_id=module_project.uid, name=name, - labels=object_predictions) + labels=object_predictions, + ) assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) + label_import.input_file_url, object_predictions + ) -def test_create_with_path_arg(client, tmp_path, module_project, object_predictions, - annotation_import_test_helpers): +def test_create_with_path_arg( + client, + tmp_path, + module_project, + object_predictions, + annotation_import_test_helpers, +): project = module_project name = str(uuid.uuid4()) file_name = f"{name}.ndjson" @@ -78,19 +89,24 @@ def test_create_with_path_arg(client, tmp_path, module_project, object_predictio with file_path.open("w") as f: parser.dump(object_predictions, f) - label_import = LabelImport.create(client=client, - id=project.uid, - name=name, - path=str(file_path)) + label_import = LabelImport.create( + client=client, id=project.uid, name=name, path=str(file_path) + ) assert label_import.parent_id == project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) + label_import.input_file_url, object_predictions + ) -def test_create_from_local_file(client, tmp_path, module_project, object_predictions, - annotation_import_test_helpers): +def test_create_from_local_file( + client, + tmp_path, + module_project, + object_predictions, + annotation_import_test_helpers, +): project = module_project name = str(uuid.uuid4()) file_name = f"{name}.ndjson" @@ -98,26 +114,23 @@ def test_create_from_local_file(client, tmp_path, module_project, object_predict with file_path.open("w") as f: parser.dump(object_predictions, f) - label_import = LabelImport.create_from_file(client=client, - project_id=project.uid, - name=name, - path=str(file_path)) + label_import = LabelImport.create_from_file( + client=client, project_id=project.uid, name=name, path=str(file_path) + ) assert 
label_import.parent_id == project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) + label_import.input_file_url, object_predictions + ) -def test_get(client, module_project, - annotation_import_test_helpers): +def test_get(client, module_project, annotation_import_test_helpers): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" label_import = LabelImport.create_from_url( - client=client, - project_id=module_project.uid, - name=name, - url=url) + client=client, project_id=module_project.uid, name=name, url=url + ) assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name, url) @@ -130,18 +143,19 @@ def test_wait_till_done(client, module_project, predictions): client=client, project_id=module_project.uid, name=name, - labels=predictions) + labels=predictions, + ) assert len(label_import.inputs) == len(predictions) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED assert len(label_import.inputs) == len(predictions) - input_uuids = [input_annot['uuid'] for input_annot in label_import.inputs] - inference_uuids = [pred['uuid'] for pred in predictions] + input_uuids = [input_annot["uuid"] for input_annot in label_import.inputs] + inference_uuids = [pred["uuid"] for pred in predictions] assert set(input_uuids) == set(inference_uuids) assert len(label_import.statuses) == len(predictions) status_uuids = [ - input_annot['uuid'] for input_annot in label_import.statuses + input_annot["uuid"] for input_annot in label_import.statuses ] assert set(input_uuids) == set(status_uuids) diff --git a/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py b/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py index c50c82315..3ffd6bfc1 100644 --- a/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_mal_prediction_import.py @@ -2,6 +2,7 @@ from labelbox import parser from labelbox.schema.annotation_import import MALPredictionImport + """ - Here we only want to check that the uploads are calling the validation - Then with unit tests we can check the types of errors raised @@ -9,37 +10,45 @@ """ -def test_create_with_url_arg(client, module_project, - annotation_import_test_helpers): +def test_create_with_url_arg( + client, module_project, annotation_import_test_helpers +): name = str(uuid.uuid4()) url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" label_import = MALPredictionImport.create( - client=client, - id=module_project.uid, - name=name, - url=url) + client=client, id=module_project.uid, name=name, url=url + ) assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name, url) -def test_create_with_labels_arg(client, module_project, object_predictions, - annotation_import_test_helpers): +def test_create_with_labels_arg( + client, module_project, object_predictions, annotation_import_test_helpers +): """this test should check running state only to validate running, not completed""" name = str(uuid.uuid4()) - label_import = MALPredictionImport.create(client=client, - id=module_project.uid, - name=name, - labels=object_predictions) + label_import = MALPredictionImport.create( + client=client, + 
id=module_project.uid, + name=name, + labels=object_predictions, + ) assert label_import.parent_id == module_project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) + label_import.input_file_url, object_predictions + ) -def test_create_with_path_arg(client, tmp_path, configured_project, object_predictions, - annotation_import_test_helpers): +def test_create_with_path_arg( + client, + tmp_path, + configured_project, + object_predictions, + annotation_import_test_helpers, +): project = configured_project name = str(uuid.uuid4()) file_name = f"{name}.ndjson" @@ -47,12 +56,12 @@ def test_create_with_path_arg(client, tmp_path, configured_project, object_predi with file_path.open("w") as f: parser.dump(object_predictions, f) - label_import = MALPredictionImport.create(client=client, - id=project.uid, - name=name, - path=str(file_path)) + label_import = MALPredictionImport.create( + client=client, id=project.uid, name=name, path=str(file_path) + ) assert label_import.parent_id == project.uid annotation_import_test_helpers.check_running_state(label_import, name) annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) + label_import.input_file_url, object_predictions + ) diff --git a/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py b/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py index f2765fd3f..fccca2a3f 100644 --- a/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py @@ -3,125 +3,160 @@ import pytest from labelbox import ModelRun -from labelbox.schema.annotation_import import AnnotationImportState, MEAPredictionImport +from labelbox.schema.annotation_import import ( + AnnotationImportState, + MEAPredictionImport, +) from labelbox.data.serialization import NDJsonConverter from labelbox.schema.export_params import ModelRunExportParams + """ - Here we only want to check that the uploads are calling the validation - Then with unit tests we can check the types of errors raised """ + @pytest.mark.order(1) -def test_create_from_objects(model_run_with_data_rows, - object_predictions_for_annotation_import, - annotation_import_test_helpers): +def test_create_from_objects( + model_run_with_data_rows, + object_predictions_for_annotation_import, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) object_predictions = object_predictions_for_annotation_import - use_data_row_ids = [p['dataRow']['id'] for p in object_predictions] + use_data_row_ids = [p["dataRow"]["id"] for p in object_predictions] model_run_with_data_rows.upsert_data_rows(use_data_row_ids) annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=object_predictions) + name=name, predictions=object_predictions + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid annotation_import_test_helpers.check_running_state(annotation_import, name) annotation_import_test_helpers.assert_file_content( - annotation_import.input_file_url, object_predictions) + annotation_import.input_file_url, object_predictions + ) annotation_import.wait_until_done() assert annotation_import.state == AnnotationImportState.FINISHED annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) + annotation_import.status_file_url + ) -def 
test_create_from_objects_global_key(client, model_run_with_data_rows, - polygon_inference, - annotation_import_test_helpers): +def test_create_from_objects_global_key( + client, + model_run_with_data_rows, + polygon_inference, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) - dr = client.get_data_row(polygon_inference[0]['dataRow']['id']) - polygon_inference[0]['dataRow']['globalKey'] = dr.global_key - del polygon_inference[0]['dataRow']['id'] + dr = client.get_data_row(polygon_inference[0]["dataRow"]["id"]) + polygon_inference[0]["dataRow"]["globalKey"] = dr.global_key + del polygon_inference[0]["dataRow"]["id"] object_predictions = [polygon_inference[0]] annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=object_predictions) + name=name, predictions=object_predictions + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid annotation_import_test_helpers.check_running_state(annotation_import, name) annotation_import_test_helpers.assert_file_content( - annotation_import.input_file_url, object_predictions) + annotation_import.input_file_url, object_predictions + ) annotation_import.wait_until_done() assert annotation_import.state == AnnotationImportState.FINISHED annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) + annotation_import.status_file_url + ) -def test_create_from_objects_with_confidence(predictions_with_confidence, - model_run_with_data_rows, - annotation_import_test_helpers): +def test_create_from_objects_with_confidence( + predictions_with_confidence, + model_run_with_data_rows, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) - object_prediction_data_rows = set([ - object_prediction["dataRow"]["id"] - for object_prediction in predictions_with_confidence - ]) + object_prediction_data_rows = set( + [ + object_prediction["dataRow"]["id"] + for object_prediction in predictions_with_confidence + ] + ) # MUST have all data rows in the model run model_run_with_data_rows.upsert_data_rows( - data_row_ids=list(object_prediction_data_rows)) + data_row_ids=list(object_prediction_data_rows) + ) annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=predictions_with_confidence) + name=name, predictions=predictions_with_confidence + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid annotation_import_test_helpers.check_running_state(annotation_import, name) annotation_import_test_helpers.assert_file_content( - annotation_import.input_file_url, predictions_with_confidence) + annotation_import.input_file_url, predictions_with_confidence + ) annotation_import.wait_until_done() assert annotation_import.state == AnnotationImportState.FINISHED annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) + annotation_import.status_file_url + ) def test_create_from_objects_all_project_labels( - model_run_with_all_project_labels, - object_predictions_for_annotation_import, - annotation_import_test_helpers): + model_run_with_all_project_labels, + object_predictions_for_annotation_import, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) object_predictions = object_predictions_for_annotation_import - use_data_row_ids = [p['dataRow']['id'] for p in object_predictions] + use_data_row_ids = [p["dataRow"]["id"] for p in object_predictions] model_run_with_all_project_labels.upsert_data_rows(use_data_row_ids) annotation_import = model_run_with_all_project_labels.add_predictions( - 
name=name, predictions=object_predictions) + name=name, predictions=object_predictions + ) - assert annotation_import.model_run_id == model_run_with_all_project_labels.uid + assert ( + annotation_import.model_run_id == model_run_with_all_project_labels.uid + ) annotation_import_test_helpers.check_running_state(annotation_import, name) annotation_import_test_helpers.assert_file_content( - annotation_import.input_file_url, object_predictions) + annotation_import.input_file_url, object_predictions + ) annotation_import.wait_until_done() assert annotation_import.state == AnnotationImportState.FINISHED annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) + annotation_import.status_file_url + ) -def test_model_run_project_labels(model_run_with_all_project_labels: ModelRun, - model_run_predictions): - +def test_model_run_project_labels( + model_run_with_all_project_labels: ModelRun, model_run_predictions +): model_run = model_run_with_all_project_labels export_task = model_run.export() export_task.wait_till_done() stream = export_task.get_buffered_stream() - + # exports to list of tuples (data_row_id, label) needed to adapt test to export v2 instead of export v1 since data rows ids are not at label level in export v2. - model_run_exported_labels = [( - data_row.json["data_row"]["id"], - data_row.json["experiments"][model_run.model_id]["runs"][model_run.uid]["labels"][0]) - for data_row in stream] - + model_run_exported_labels = [ + ( + data_row.json["data_row"]["id"], + data_row.json["experiments"][model_run.model_id]["runs"][ + model_run.uid + ]["labels"][0], + ) + for data_row in stream + ] + labels_indexed_by_name = {} # making sure the labels are in this model run are all labels uploaded to the project @@ -130,51 +165,69 @@ def test_model_run_project_labels(model_run_with_all_project_labels: ModelRun, for data_row_id, label in model_run_exported_labels: for object in label["annotations"]["objects"]: name = object["name"] - labels_indexed_by_name[f"{name}-{data_row_id}"] = {"label": label, "data_row_id": data_row_id} - - assert (len( - labels_indexed_by_name.keys())) == len([prediction["dataRow"]["id"] for prediction in model_run_predictions]) - - expected_data_row_ids = set([prediction["dataRow"]["id"] for prediction in model_run_predictions]) - expected_objects = set([prediction["name"] for prediction in model_run_predictions]) + labels_indexed_by_name[f"{name}-{data_row_id}"] = { + "label": label, + "data_row_id": data_row_id, + } + + assert (len(labels_indexed_by_name.keys())) == len( + [prediction["dataRow"]["id"] for prediction in model_run_predictions] + ) + + expected_data_row_ids = set( + [prediction["dataRow"]["id"] for prediction in model_run_predictions] + ) + expected_objects = set( + [prediction["name"] for prediction in model_run_predictions] + ) for data_row_id, actual_label in model_run_exported_labels: assert data_row_id in expected_data_row_ids - assert len(expected_objects) == len(actual_label["annotations"]["objects"]) + assert len(expected_objects) == len( + actual_label["annotations"]["objects"] + ) - -def test_create_from_label_objects(model_run_with_data_rows, - object_predictions_for_annotation_import, - annotation_import_test_helpers): +def test_create_from_label_objects( + model_run_with_data_rows, + object_predictions_for_annotation_import, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) use_data_row_ids = [ - p['dataRow']['id'] for p in object_predictions_for_annotation_import + p["dataRow"]["id"] for p in 
object_predictions_for_annotation_import ] model_run_with_data_rows.upsert_data_rows(use_data_row_ids) predictions = list( - NDJsonConverter.deserialize(object_predictions_for_annotation_import)) + NDJsonConverter.deserialize(object_predictions_for_annotation_import) + ) annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=predictions) + name=name, predictions=predictions + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid annotation_import_test_helpers.check_running_state(annotation_import, name) normalized_predictions = NDJsonConverter.serialize(predictions) annotation_import_test_helpers.assert_file_content( - annotation_import.input_file_url, normalized_predictions) + annotation_import.input_file_url, normalized_predictions + ) annotation_import.wait_until_done() assert annotation_import.state == AnnotationImportState.FINISHED annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) + annotation_import.status_file_url + ) -def test_create_from_local_file(tmp_path, model_run_with_data_rows, - object_predictions_for_annotation_import, - annotation_import_test_helpers): +def test_create_from_local_file( + tmp_path, + model_run_with_data_rows, + object_predictions_for_annotation_import, + annotation_import_test_helpers, +): use_data_row_ids = [ - p['dataRow']['id'] for p in object_predictions_for_annotation_import + p["dataRow"]["id"] for p in object_predictions_for_annotation_import ] model_run_with_data_rows.upsert_data_rows(use_data_row_ids) @@ -185,30 +238,36 @@ def test_create_from_local_file(tmp_path, model_run_with_data_rows, parser.dump(object_predictions_for_annotation_import, f) annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=str(file_path)) + name=name, predictions=str(file_path) + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid annotation_import_test_helpers.check_running_state(annotation_import, name) annotation_import_test_helpers.assert_file_content( annotation_import.input_file_url, - object_predictions_for_annotation_import) + object_predictions_for_annotation_import, + ) annotation_import.wait_until_done() assert annotation_import.state == AnnotationImportState.FINISHED annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) + annotation_import.status_file_url + ) def test_predictions_with_custom_metrics( - model_run, object_predictions_for_annotation_import, - annotation_import_test_helpers): + model_run, + object_predictions_for_annotation_import, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) object_predictions = object_predictions_for_annotation_import - use_data_row_ids = [p['dataRow']['id'] for p in object_predictions] + use_data_row_ids = [p["dataRow"]["id"] for p in object_predictions] model_run.upsert_data_rows(use_data_row_ids) annotation_import = model_run.add_predictions( - name=name, predictions=object_predictions) + name=name, predictions=object_predictions + ) assert annotation_import.model_run_id == model_run.uid annotation_import.wait_until_done() @@ -219,7 +278,8 @@ def test_predictions_with_custom_metrics( assert annotation_import.state == AnnotationImportState.FINISHED annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) + annotation_import.status_file_url + ) def test_get(client, model_run_with_data_rows, annotation_import_test_helpers): @@ -228,11 +288,13 @@ def test_get(client, 
model_run_with_data_rows, annotation_import_test_helpers): model_run_with_data_rows.add_predictions(name=name, predictions=url) annotation_import = MEAPredictionImport.from_name( - client, model_run_id=model_run_with_data_rows.uid, name=name) + client, model_run_id=model_run_with_data_rows.uid, name=name + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid - annotation_import_test_helpers.check_running_state(annotation_import, name, - url) + annotation_import_test_helpers.check_running_state( + annotation_import, name, url + ) annotation_import.wait_until_done() @@ -240,7 +302,8 @@ def test_get(client, model_run_with_data_rows, annotation_import_test_helpers): def test_wait_till_done(model_run_predictions, model_run_with_data_rows): name = str(uuid.uuid4()) annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=model_run_predictions) + name=name, predictions=model_run_predictions + ) assert len(annotation_import.inputs) == len(model_run_predictions) annotation_import.wait_until_done() @@ -249,14 +312,14 @@ def test_wait_till_done(model_run_predictions, model_run_with_data_rows): assert len(annotation_import.errors) == 0 assert len(annotation_import.inputs) == len(model_run_predictions) input_uuids = [ - input_annot['uuid'] for input_annot in annotation_import.inputs + input_annot["uuid"] for input_annot in annotation_import.inputs ] - inference_uuids = [pred['uuid'] for pred in model_run_predictions] + inference_uuids = [pred["uuid"] for pred in model_run_predictions] assert set(input_uuids) == set(inference_uuids) assert len(annotation_import.statuses) == len(model_run_predictions) for status in annotation_import.statuses: - assert status['status'] == 'SUCCESS' + assert status["status"] == "SUCCESS" status_uuids = [ - input_annot['uuid'] for input_annot in annotation_import.statuses + input_annot["uuid"] for input_annot in annotation_import.statuses ] assert set(input_uuids) == set(status_uuids) diff --git a/libs/labelbox/tests/data/annotation_import/test_model_run.py b/libs/labelbox/tests/data/annotation_import/test_model_run.py index bf30ed169..9eca28429 100644 --- a/libs/labelbox/tests/data/annotation_import/test_model_run.py +++ b/libs/labelbox/tests/data/annotation_import/test_model_run.py @@ -6,6 +6,7 @@ from labelbox import DataSplit, ModelRun + @pytest.mark.order(1) def test_model_run(client, configured_project_with_label, data_row, rand_gen): project, _, _, label = configured_project_with_label @@ -87,19 +88,19 @@ def test_model_run_data_rows_delete(model_run_with_data_rows): assert len(before) == len(after) + 1 -def test_model_run_upsert_data_rows(dataset, model_run, - configured_project): +def test_model_run_upsert_data_rows(dataset, model_run, configured_project): n_model_run_data_rows = len(list(model_run.model_run_data_rows())) assert n_model_run_data_rows == 0 data_row = dataset.create_data_row(row_data="test row data") configured_project._wait_until_data_rows_are_processed( - data_row_ids=[data_row.uid]) + data_row_ids=[data_row.uid] + ) model_run.upsert_data_rows([data_row.uid]) n_model_run_data_rows = len(list(model_run.model_run_data_rows())) assert n_model_run_data_rows == 1 -@pytest.mark.parametrize('data_rows', [2], indirect=True) +@pytest.mark.parametrize("data_rows", [2], indirect=True) def test_model_run_upsert_data_rows_using_global_keys(model_run, data_rows): global_keys = [dr.global_key for dr in data_rows] assert model_run.upsert_data_rows(global_keys=global_keys) @@ -109,68 +110,77 @@ def 
test_model_run_upsert_data_rows_using_global_keys(model_run, data_rows): def test_model_run_upsert_data_rows_with_existing_labels( - model_run_with_data_rows): + model_run_with_data_rows, +): model_run_data_rows = list(model_run_with_data_rows.model_run_data_rows()) n_data_rows = len(model_run_data_rows) - model_run_with_data_rows.upsert_data_rows([ - model_run_data_row.data_row().uid - for model_run_data_row in model_run_data_rows - ]) + model_run_with_data_rows.upsert_data_rows( + [ + model_run_data_row.data_row().uid + for model_run_data_row in model_run_data_rows + ] + ) assert n_data_rows == len( - list(model_run_with_data_rows.model_run_data_rows())) + list(model_run_with_data_rows.model_run_data_rows()) + ) -@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", - reason="does not work for onprem") +@pytest.mark.skipif( + condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem", + reason="does not work for onprem", +) def test_model_run_status(model_run_with_data_rows): - def get_model_run_status(): return model_run_with_data_rows.client.execute( """query trainingPipelinePyApi($modelRunId: ID!) { trainingPipeline(where: {id : $modelRunId}) {status, errorMessage, metadata}} - """, {'modelRunId': model_run_with_data_rows.uid}, - experimental=True)['trainingPipeline'] + """, + {"modelRunId": model_run_with_data_rows.uid}, + experimental=True, + )["trainingPipeline"] model_run_status = get_model_run_status() - assert model_run_status['status'] is None - assert model_run_status['metadata'] is None - assert model_run_status['errorMessage'] is None + assert model_run_status["status"] is None + assert model_run_status["metadata"] is None + assert model_run_status["errorMessage"] is None status = "COMPLETE" - metadata = {'key1': 'value1'} + metadata = {"key1": "value1"} errorMessage = "an error" model_run_with_data_rows.update_status(status, metadata, errorMessage) model_run_status = get_model_run_status() - assert model_run_status['status'] == status - assert model_run_status['metadata'] == metadata - assert model_run_status['errorMessage'] == errorMessage + assert model_run_status["status"] == status + assert model_run_status["metadata"] == metadata + assert model_run_status["errorMessage"] == errorMessage - extra_metadata = {'key2': 'value2'} + extra_metadata = {"key2": "value2"} model_run_with_data_rows.update_status(status, extra_metadata) model_run_status = get_model_run_status() - assert model_run_status['status'] == status - assert model_run_status['metadata'] == {**metadata, **extra_metadata} - assert model_run_status['errorMessage'] == errorMessage + assert model_run_status["status"] == status + assert model_run_status["metadata"] == {**metadata, **extra_metadata} + assert model_run_status["errorMessage"] == errorMessage status = ModelRun.Status.FAILED model_run_with_data_rows.update_status(status, metadata, errorMessage) model_run_status = get_model_run_status() - assert model_run_status['status'] == status.value + assert model_run_status["status"] == status.value with pytest.raises(ValueError): - model_run_with_data_rows.update_status("INVALID", metadata, - errorMessage) + model_run_with_data_rows.update_status( + "INVALID", metadata, errorMessage + ) -def test_model_run_split_assignment_by_data_row_ids(model_run, dataset, - image_url): +def test_model_run_split_assignment_by_data_row_ids( + model_run, dataset, image_url +): n_data_rows = 2 - data_rows = dataset.create_data_rows([{ - "row_data": image_url - } for _ in range(n_data_rows)]) + data_rows = 
dataset.create_data_rows( + [{"row_data": image_url} for _ in range(n_data_rows)] + ) data_rows.wait_till_done() - data_row_ids = [data_row['id'] for data_row in data_rows.result] + data_row_ids = [data_row["id"] for data_row in data_rows.result] model_run.upsert_data_rows(data_row_ids) with pytest.raises(ValueError): @@ -185,15 +195,16 @@ def test_model_run_split_assignment_by_data_row_ids(model_run, dataset, assert counts[split] == n_data_rows -@pytest.mark.parametrize('data_rows', [2], indirect=True) +@pytest.mark.parametrize("data_rows", [2], indirect=True) def test_model_run_split_assignment_by_global_keys(model_run, data_rows): global_keys = [data_row.global_key for data_row in data_rows] model_run.upsert_data_rows(global_keys=global_keys) for split in ["TRAINING", "TEST", "VALIDATION", "UNASSIGNED", *DataSplit]: - model_run.assign_data_rows_to_split(split=split, - global_keys=global_keys) + model_run.assign_data_rows_to_split( + split=split, global_keys=global_keys + ) splits = [ data_row.data_split.value for data_row in model_run.model_run_data_rows() diff --git a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py index ac197a321..a0df559fc 100644 --- a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py +++ b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py @@ -6,15 +6,25 @@ from pytest_cases import parametrize, fixture_ref from labelbox.exceptions import MALValidationError -from labelbox.schema.bulk_import_request import (NDChecklist, NDClassification, - NDMask, NDPolygon, NDPolyline, - NDRadio, NDRectangle, NDText, - NDTextEntity, NDTool, - _validate_ndjson) +from labelbox.schema.bulk_import_request import ( + NDChecklist, + NDClassification, + NDMask, + NDPolygon, + NDPolyline, + NDRadio, + NDRectangle, + NDText, + NDTextEntity, + NDTool, + _validate_ndjson, +) + """ - These NDlabels are apart of bulkImportReqeust and should be removed once bulk import request is removed """ + def test_classification_construction(checklist_inference, text_inference): checklist = NDClassification.build(checklist_inference[0]) assert isinstance(checklist, NDChecklist) @@ -22,97 +32,93 @@ def test_classification_construction(checklist_inference, text_inference): assert isinstance(text, NDText) -@parametrize("inference, expected_type", - [(fixture_ref('polygon_inference'), NDPolygon), - (fixture_ref('rectangle_inference'), NDRectangle), - (fixture_ref('line_inference'), NDPolyline), - (fixture_ref('entity_inference'), NDTextEntity), - (fixture_ref('segmentation_inference'), NDMask), - (fixture_ref('segmentation_inference_rle'), NDMask), - (fixture_ref('segmentation_inference_png'), NDMask)]) +@parametrize( + "inference, expected_type", + [ + (fixture_ref("polygon_inference"), NDPolygon), + (fixture_ref("rectangle_inference"), NDRectangle), + (fixture_ref("line_inference"), NDPolyline), + (fixture_ref("entity_inference"), NDTextEntity), + (fixture_ref("segmentation_inference"), NDMask), + (fixture_ref("segmentation_inference_rle"), NDMask), + (fixture_ref("segmentation_inference_png"), NDMask), + ], +) def test_tool_construction(inference, expected_type): assert isinstance(NDTool.build(inference[0]), expected_type) def no_tool(text_inference, module_project): pred = text_inference[0].copy() - #Missing key - del pred['answer'] + # Missing key + del pred["answer"] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) -@pytest.mark.parametrize( - 
"configured_project", - [MediaType.Text], - indirect=True -) + +@pytest.mark.parametrize("configured_project", [MediaType.Text], indirect=True) def test_invalid_text(text_inference, configured_project): - #and if it is not a string + # and if it is not a string pred = text_inference[0].copy() - #Extra and wrong key - del pred['answer'] - pred['answers'] = [] + # Extra and wrong key + del pred["answer"] + pred["answers"] = [] with pytest.raises(MALValidationError): _validate_ndjson([pred], configured_project) - del pred['answers'] + del pred["answers"] - #Invalid type - pred['answer'] = [] + # Invalid type + pred["answer"] = [] with pytest.raises(MALValidationError): _validate_ndjson([pred], configured_project) - #Invalid type - pred['answer'] = None + # Invalid type + pred["answer"] = None with pytest.raises(MALValidationError): _validate_ndjson([pred], configured_project) -def test_invalid_checklist_item(checklist_inference, - module_project): - #Only two points +def test_invalid_checklist_item(checklist_inference, module_project): + # Only two points pred = checklist_inference[0].copy() - pred['answers'] = [pred['answers'][0], pred['answers'][0]] - #Duplicate schema ids + pred["answers"] = [pred["answers"][0], pred["answers"][0]] + # Duplicate schema ids with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) - pred['answers'] = [{"name": "asdfg"}] + pred["answers"] = [{"name": "asdfg"}] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) - pred['answers'] = [{"schemaId": "1232132132"}] + pred["answers"] = [{"schemaId": "1232132132"}] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) - pred['answers'] = [{}] + pred["answers"] = [{}] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) - pred['answers'] = [] + pred["answers"] = [] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) - del pred['answers'] + del pred["answers"] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) def test_invalid_polygon(polygon_inference, module_project): - #Only two points + # Only two points pred = polygon_inference[0].copy() - pred['polygon'] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}] + pred["polygon"] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) -@pytest.mark.parametrize( - "configured_project", - [MediaType.Text], - indirect=True -) +@pytest.mark.parametrize("configured_project", [MediaType.Text], indirect=True) def test_incorrect_entity(entity_inference, configured_project): entity = entity_inference[0].copy() - #Location cannot be a list + # Location cannot be a list entity["location"] = [0, 10] with pytest.raises(MALValidationError): _validate_ndjson([entity], configured_project) @@ -126,53 +132,50 @@ def test_incorrect_entity(entity_inference, configured_project): _validate_ndjson([entity], configured_project) -@pytest.mark.skip("Test wont work/fails randomly since projects have to have a media type and could be missing features from prediction list") +@pytest.mark.skip( + "Test wont work/fails randomly since projects have to have a media type and could be missing features from prediction list" +) def test_all_validate_json(module_project, predictions): - #Predictions contains one of each type of prediction. - #These should be properly formatted and pass. + # Predictions contains one of each type of prediction. 
+ # These should be properly formatted and pass. _validate_ndjson(predictions[0], module_project) def test_incorrect_line(line_inference, module_project): line = line_inference[0].copy() - line["line"] = [line["line"][0]] #Just one point + line["line"] = [line["line"][0]] # Just one point with pytest.raises(MALValidationError): _validate_ndjson([line], module_project) -def test_incorrect_rectangle(rectangle_inference, - module_project): - del rectangle_inference[0]['bbox']['top'] +def test_incorrect_rectangle(rectangle_inference, module_project): + del rectangle_inference[0]["bbox"]["top"] with pytest.raises(MALValidationError): - _validate_ndjson([rectangle_inference], - module_project) + _validate_ndjson([rectangle_inference], module_project) def test_duplicate_tools(rectangle_inference, module_project): pred = rectangle_inference[0].copy() - pred['polygon'] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}] + pred["polygon"] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) -def test_invalid_feature_schema(module_project, - rectangle_inference): +def test_invalid_feature_schema(module_project, rectangle_inference): pred = rectangle_inference[0].copy() - pred['schemaId'] = "blahblah" + pred["schemaId"] = "blahblah" with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) -def test_name_only_feature_schema(module_project, - rectangle_inference): +def test_name_only_feature_schema(module_project, rectangle_inference): pred = rectangle_inference[0].copy() _validate_ndjson([pred], module_project) -def test_schema_id_only_feature_schema(module_project, - rectangle_inference): +def test_schema_id_only_feature_schema(module_project, rectangle_inference): pred = rectangle_inference[0].copy() - del pred['name'] + del pred["name"] ontology = module_project.ontology().normalized["tools"] for tool in ontology: if tool["name"] == "bbox": @@ -181,10 +184,9 @@ def test_schema_id_only_feature_schema(module_project, _validate_ndjson([pred], module_project) -def test_missing_feature_schema(module_project, - rectangle_inference): +def test_missing_feature_schema(module_project, rectangle_inference): pred = rectangle_inference[0].copy() - del pred['name'] + del pred["name"] with pytest.raises(MALValidationError): _validate_ndjson([pred], module_project) @@ -197,31 +199,32 @@ def test_validate_ndjson(tmp_path, configured_project): with pytest.raises(ValueError): configured_project.upload_annotations( - name="name", annotations=str(file_path), validate=True) + name="name", annotations=str(file_path), validate=True + ) -def test_validate_ndjson_uuid(tmp_path, configured_project, - predictions): +def test_validate_ndjson_uuid(tmp_path, configured_project, predictions): file_name = f"repeat_uuid.ndjson" file_path = tmp_path / file_name repeat_uuid = predictions.copy() - repeat_uuid[0]['uuid'] = 'test_uuid' - repeat_uuid[1]['uuid'] = 'test_uuid' + repeat_uuid[0]["uuid"] = "test_uuid" + repeat_uuid[1]["uuid"] = "test_uuid" with file_path.open("w") as f: parser.dump(repeat_uuid, f) with pytest.raises(MALValidationError): configured_project.upload_annotations( - name="name", validate=True, annotations=str(file_path)) + name="name", validate=True, annotations=str(file_path) + ) with pytest.raises(MALValidationError): configured_project.upload_annotations( - name="name", validate=True, annotations=repeat_uuid) + name="name", validate=True, annotations=repeat_uuid + ) @pytest.mark.parametrize("configured_project", 
[MediaType.Video], indirect=True) -def test_video_upload(video_checklist_inference, - configured_project): +def test_video_upload(video_checklist_inference, configured_project): pred = video_checklist_inference[0].copy() _validate_ndjson([pred], configured_project) diff --git a/libs/labelbox/tests/data/annotation_import/test_send_to_annotate_mea.py b/libs/labelbox/tests/data/annotation_import/test_send_to_annotate_mea.py index 1f8b84742..4bcd4dcef 100644 --- a/libs/labelbox/tests/data/annotation_import/test_send_to_annotate_mea.py +++ b/libs/labelbox/tests/data/annotation_import/test_send_to_annotate_mea.py @@ -1,14 +1,22 @@ import pytest from labelbox import UniqueIds, OntologyBuilder -from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy +from labelbox.schema.conflict_resolution_strategy import ( + ConflictResolutionStrategy, +) -def test_send_to_annotate_from_model(client, configured_project, - model_run_predictions, - model_run_with_data_rows, project): +def test_send_to_annotate_from_model( + client, + configured_project, + model_run_predictions, + model_run_with_data_rows, + project, +): model_run = model_run_with_data_rows - data_row_ids = list(set([p['dataRow']['id'] for p in model_run_predictions])) + data_row_ids = list( + set([p["dataRow"]["id"] for p in model_run_predictions]) + ) assert len(data_row_ids) > 0 destination_project = project @@ -18,22 +26,27 @@ def test_send_to_annotate_from_model(client, configured_project, queues = destination_project.task_queues() initial_review_task = next( - q for q in queues if q.name == "Initial review task") + q for q in queues if q.name == "Initial review task" + ) # build an ontology mapping using the top level tools and classifications source_ontology_builder = OntologyBuilder.from_project(configured_project) feature_schema_ids = list( - tool.feature_schema_id for tool in source_ontology_builder.tools) + tool.feature_schema_id for tool in source_ontology_builder.tools + ) # create a dictionary of feature schema id to itself ontology_mapping = dict(zip(feature_schema_ids, feature_schema_ids)) classification_feature_schema_ids = list( classification.feature_schema_id - for classification in source_ontology_builder.classifications) + for classification in source_ontology_builder.classifications + ) # create a dictionary of feature schema id to itself classification_ontology_mapping = dict( - zip(classification_feature_schema_ids, - classification_feature_schema_ids)) + zip( + classification_feature_schema_ids, classification_feature_schema_ids + ) + ) # combine the two ontology mappings ontology_mapping.update(classification_ontology_mapping) @@ -44,11 +57,10 @@ def test_send_to_annotate_from_model(client, configured_project, data_rows=UniqueIds(data_row_ids), task_queue_id=initial_review_task.uid, params={ - "predictions_ontology_mapping": - ontology_mapping, - "override_existing_annotations_rule": - ConflictResolutionStrategy.OverrideWithPredictions - }) + "predictions_ontology_mapping": ontology_mapping, + "override_existing_annotations_rule": ConflictResolutionStrategy.OverrideWithPredictions, + }, + ) task.wait_till_done() @@ -66,5 +78,5 @@ def test_send_to_annotate_from_model(client, configured_project, assert all([dr in data_row_ids for dr in destination_data_rows]) # Since data rows were added to a review queue, predictions should be imported into the project as labels - destination_project_labels = (list(destination_project.labels())) + destination_project_labels = list(destination_project.labels()) 
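As a reading aid for the hunk above (not part of the patch): the keyword arguments being reformatted belong to a send-to-annotate call that pushes model-run predictions into a project's review queue. A hedged sketch of that flow follows; the method name send_to_annotate_from_model, the destination_project_id parameter, and the ids/credentials are assumptions based on the surrounding test rather than anything changed in this diff.

from labelbox import Client, OntologyBuilder, UniqueIds
from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy

client = Client(api_key="LB_API_KEY")              # assumed credentials
model_run = client.get_model_run("MODEL_RUN_ID")   # assumed ids
destination_project = client.get_project("PROJECT_ID")

# Map every top-level feature schema onto itself, as the test does.
builder = OntologyBuilder.from_project(destination_project)
mapping = {t.feature_schema_id: t.feature_schema_id for t in builder.tools}
mapping.update(
    {c.feature_schema_id: c.feature_schema_id for c in builder.classifications}
)

# Send predictions for a set of data rows into the project's
# "Initial review task" queue, letting predictions win on conflicts.
review_queue = next(
    q for q in destination_project.task_queues()
    if q.name == "Initial review task"
)
task = model_run.send_to_annotate_from_model(      # assumed method name
    destination_project_id=destination_project.uid,
    data_rows=UniqueIds(["DATA_ROW_ID"]),          # assumed data row id
    task_queue_id=review_queue.uid,
    params={
        "predictions_ontology_mapping": mapping,
        "override_existing_annotations_rule":
            ConflictResolutionStrategy.OverrideWithPredictions,
    },
)
task.wait_till_done()
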
assert len(destination_project_labels) == len(data_row_ids) diff --git a/libs/labelbox/tests/data/annotation_import/test_upsert_prediction_import.py b/libs/labelbox/tests/data/annotation_import/test_upsert_prediction_import.py index 59c894c65..a60e0aa59 100644 --- a/libs/labelbox/tests/data/annotation_import/test_upsert_prediction_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_upsert_prediction_import.py @@ -1,6 +1,7 @@ import uuid from labelbox import parser import pytest + """ - Here we only want to check that the uploads are calling the validation - Then with unit tests we can check the types of errors raised @@ -9,10 +10,14 @@ @pytest.mark.skip() -def test_create_from_url(client, tmp_path, object_predictions, - model_run_with_data_rows, - configured_project, - annotation_import_test_helpers): +def test_create_from_url( + client, + tmp_path, + object_predictions, + model_run_with_data_rows, + configured_project, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) file_name = f"{name}.json" file_path = tmp_path / file_name @@ -22,8 +27,9 @@ def test_create_from_url(client, tmp_path, object_predictions, for mrdr in model_run_with_data_rows.model_run_data_rows() ] predictions = [ - p for p in object_predictions - if p['dataRow']['id'] in model_run_data_rows + p + for p in object_predictions + if p["dataRow"]["id"] in model_run_data_rows ] with file_path.open("w") as f: parser.dump(predictions, f) @@ -31,16 +37,21 @@ def test_create_from_url(client, tmp_path, object_predictions, # Needs to have data row ids with open(file_path, "r") as f: - url = client.upload_data(content=f.read(), - filename=file_name, - sign=True, - content_type="application/json") - - annotation_import, batch, mal_prediction_import = model_run_with_data_rows.upsert_predictions_and_send_to_project( - name=name, - predictions=url, - project_id=configured_project.uid, - priority=5) + url = client.upload_data( + content=f.read(), + filename=file_name, + sign=True, + content_type="application/json", + ) + + annotation_import, batch, mal_prediction_import = ( + model_run_with_data_rows.upsert_predictions_and_send_to_project( + name=name, + predictions=url, + project_id=configured_project.uid, + priority=5, + ) + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid annotation_import.wait_until_done() @@ -58,24 +69,30 @@ def test_create_from_url(client, tmp_path, object_predictions, @pytest.mark.skip() -def test_create_from_objects(model_run_with_data_rows, - configured_project, - object_predictions, - annotation_import_test_helpers): +def test_create_from_objects( + model_run_with_data_rows, + configured_project, + object_predictions, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) model_run_data_rows = [ mrdr.data_row().uid for mrdr in model_run_with_data_rows.model_run_data_rows() ] predictions = [ - p for p in object_predictions - if p['dataRow']['id'] in model_run_data_rows + p + for p in object_predictions + if p["dataRow"]["id"] in model_run_data_rows ] - annotation_import, batch, mal_prediction_import = model_run_with_data_rows.upsert_predictions_and_send_to_project( - name=name, - predictions=predictions, - project_id=configured_project.uid, - priority=5) + annotation_import, batch, mal_prediction_import = ( + model_run_with_data_rows.upsert_predictions_and_send_to_project( + name=name, + predictions=predictions, + project_id=configured_project.uid, + priority=5, + ) + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid 
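For orientation (not part of the patch): the skipped tests in this file exercise upsert_predictions_and_send_to_project, which imports predictions into a model run and queues them against a project in one call, returning the annotation import, the batch, and the MAL prediction import. A minimal sketch under assumed ids and credentials:

from labelbox import Client

client = Client(api_key="LB_API_KEY")              # assumed credentials
model_run = client.get_model_run("MODEL_RUN_ID")   # assumed ids

# predictions may be a hosted NDJSON URL, a local NDJSON path, or a list
# of prediction dicts, exactly as in the tests above.
annotation_import, batch, mal_import = (
    model_run.upsert_predictions_and_send_to_project(
        name="prediction-upload",
        predictions="https://example.com/predictions.ndjson",  # assumed URL
        project_id="PROJECT_ID",
        priority=5,
    )
)

annotation_import.wait_until_done()
print(annotation_import.state)  # expected: AnnotationImportState.FINISHED
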
annotation_import.wait_until_done() @@ -93,11 +110,13 @@ def test_create_from_objects(model_run_with_data_rows, @pytest.mark.skip() -def test_create_from_local_file(tmp_path, model_run_with_data_rows, - configured_project_with_one_data_row, - object_predictions, - annotation_import_test_helpers): - +def test_create_from_local_file( + tmp_path, + model_run_with_data_rows, + configured_project_with_one_data_row, + object_predictions, + annotation_import_test_helpers, +): name = str(uuid.uuid4()) file_name = f"{name}.ndjson" file_path = tmp_path / file_name @@ -107,18 +126,22 @@ def test_create_from_local_file(tmp_path, model_run_with_data_rows, for mrdr in model_run_with_data_rows.model_run_data_rows() ] predictions = [ - p for p in object_predictions - if p['dataRow']['id'] in model_run_data_rows + p + for p in object_predictions + if p["dataRow"]["id"] in model_run_data_rows ] with file_path.open("w") as f: parser.dump(predictions, f) - annotation_import, batch, mal_prediction_import = model_run_with_data_rows.upsert_predictions_and_send_to_project( - name=name, - predictions=str(file_path), - project_id=configured_project_with_one_data_row.uid, - priority=5) + annotation_import, batch, mal_prediction_import = ( + model_run_with_data_rows.upsert_predictions_and_send_to_project( + name=name, + predictions=str(file_path), + project_id=configured_project_with_one_data_row.uid, + priority=5, + ) + ) assert annotation_import.model_run_id == model_run_with_data_rows.uid annotation_import.wait_until_done() diff --git a/libs/labelbox/tests/data/annotation_types/classification/test_classification.py b/libs/labelbox/tests/data/annotation_types/classification/test_classification.py index 066cf91bd..801cdb232 100644 --- a/libs/labelbox/tests/data/annotation_types/classification/test_classification.py +++ b/libs/labelbox/tests/data/annotation_types/classification/test_classification.py @@ -1,8 +1,12 @@ import pytest -from labelbox.data.annotation_types import (Checklist, ClassificationAnswer, - Radio, Text, - ClassificationAnnotation) +from labelbox.data.annotation_types import ( + Checklist, + ClassificationAnswer, + Radio, + Text, + ClassificationAnnotation, +) from pydantic import ValidationError @@ -14,18 +18,21 @@ def test_classification_answer(): feature_schema_id = "immunoelectrophoretically" name = "my_feature" confidence = 0.9 - custom_metrics = [{'name': 'metric1', 'value': 2}] - answer = ClassificationAnswer(name=name, - confidence=confidence, - custom_metrics=custom_metrics) + custom_metrics = [{"name": "metric1", "value": 2}] + answer = ClassificationAnswer( + name=name, confidence=confidence, custom_metrics=custom_metrics + ) assert answer.feature_schema_id is None assert answer.name == name assert answer.confidence == confidence - assert [answer.custom_metrics[0].model_dump(exclude_none=True)] == custom_metrics + assert [ + answer.custom_metrics[0].model_dump(exclude_none=True) + ] == custom_metrics - answer = ClassificationAnswer(feature_schema_id=feature_schema_id, - name=name) + answer = ClassificationAnswer( + feature_schema_id=feature_schema_id, name=name + ) assert answer.feature_schema_id == feature_schema_id assert answer.name == name @@ -33,9 +40,13 @@ def test_classification_answer(): def test_classification(): answer = "1234" - classification = ClassificationAnnotation(value=Text(answer=answer), - name="a classification") - assert classification.model_dump(exclude_none=True)['value']['answer'] == answer + classification = ClassificationAnnotation( + value=Text(answer=answer), 
name="a classification" + ) + assert ( + classification.model_dump(exclude_none=True)["value"]["answer"] + == answer + ) with pytest.raises(ValidationError): ClassificationAnnotation() @@ -48,107 +59,98 @@ def test_subclass(): with pytest.raises(ValidationError): # Should have feature schema info classification = ClassificationAnnotation(value=Text(answer=answer)) - classification = ClassificationAnnotation(value=Text(answer=answer), - name=name) + classification = ClassificationAnnotation( + value=Text(answer=answer), name=name + ) assert classification.model_dump(exclude_none=True) == { - 'name': name, - 'extra': {}, - 'value': { - 'answer': answer, + "name": name, + "extra": {}, + "value": { + "answer": answer, }, } classification = ClassificationAnnotation( value=Text(answer=answer), name=name, - feature_schema_id=feature_schema_id) + feature_schema_id=feature_schema_id, + ) assert classification.model_dump(exclude_none=True) == { - 'feature_schema_id': feature_schema_id, - 'extra': {}, - 'value': { - 'answer': answer, + "feature_schema_id": feature_schema_id, + "extra": {}, + "value": { + "answer": answer, }, - 'name': name, + "name": name, } classification = ClassificationAnnotation( value=Text(answer=answer), feature_schema_id=feature_schema_id, - name=name) + name=name, + ) assert classification.model_dump(exclude_none=True) == { - 'name': name, - 'feature_schema_id': feature_schema_id, - 'extra': {}, - 'value': { - 'answer': answer, + "name": name, + "feature_schema_id": feature_schema_id, + "extra": {}, + "value": { + "answer": answer, }, } def test_radio(): - answer = ClassificationAnswer(name="1", - confidence=0.81, - custom_metrics=[{ - 'name': 'metric1', - 'value': 0.99 - }]) + answer = ClassificationAnswer( + name="1", + confidence=0.81, + custom_metrics=[{"name": "metric1", "value": 0.99}], + ) feature_schema_id = "immunoelectrophoretically" name = "my_feature" with pytest.raises(ValidationError): - classification = ClassificationAnnotation(value=Radio( - answer=answer.name)) + classification = ClassificationAnnotation( + value=Radio(answer=answer.name) + ) with pytest.raises(ValidationError): classification = Radio(answer=[answer]) classification = Radio(answer=answer) assert classification.model_dump(exclude_none=True) == { - 'answer': { - 'name': answer.name, - 'extra': {}, - 'confidence': 0.81, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 0.99 - }], + "answer": { + "name": answer.name, + "extra": {}, + "confidence": 0.81, + "custom_metrics": [{"name": "metric1", "value": 0.99}], } } classification = ClassificationAnnotation( value=Radio(answer=answer), feature_schema_id=feature_schema_id, name=name, - custom_metrics=[{ - 'name': 'metric1', - 'value': 0.99 - }]) + custom_metrics=[{"name": "metric1", "value": 0.99}], + ) assert classification.model_dump(exclude_none=True) == { - 'name': name, - 'feature_schema_id': feature_schema_id, - 'extra': {}, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 0.99 - }], - 'value': { - 'answer': { - 'name': answer.name, - 'extra': {}, - 'confidence': 0.81, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 0.99 - }] + "name": name, + "feature_schema_id": feature_schema_id, + "extra": {}, + "custom_metrics": [{"name": "metric1", "value": 0.99}], + "value": { + "answer": { + "name": answer.name, + "extra": {}, + "confidence": 0.81, + "custom_metrics": [{"name": "metric1", "value": 0.99}], }, }, } def test_checklist(): - answer = ClassificationAnswer(name="1", - confidence=0.99, - custom_metrics=[{ - 'name': 'metric1', 
- 'value': 2 - }]) + answer = ClassificationAnswer( + name="1", + confidence=0.99, + custom_metrics=[{"name": "metric1", "value": 2}], + ) feature_schema_id = "immunoelectrophoretically" name = "my_feature" @@ -160,15 +162,14 @@ def test_checklist(): classification = Checklist(answer=[answer]) assert classification.model_dump(exclude_none=True) == { - 'answer': [{ - 'name': answer.name, - 'extra': {}, - 'confidence': 0.99, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 2 - }], - }] + "answer": [ + { + "name": answer.name, + "extra": {}, + "confidence": 0.99, + "custom_metrics": [{"name": "metric1", "value": 2}], + } + ] } classification = ClassificationAnnotation( value=Checklist(answer=[answer]), @@ -176,18 +177,17 @@ def test_checklist(): name=name, ) assert classification.model_dump(exclude_none=True) == { - 'name': name, - 'feature_schema_id': feature_schema_id, - 'extra': {}, - 'value': { - 'answer': [{ - 'name': answer.name, - 'extra': {}, - 'confidence': 0.99, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 2 - }], - }] + "name": name, + "feature_schema_id": feature_schema_id, + "extra": {}, + "value": { + "answer": [ + { + "name": answer.name, + "extra": {}, + "confidence": 0.99, + "custom_metrics": [{"name": "metric1", "value": 2}], + } + ] }, } diff --git a/libs/labelbox/tests/data/annotation_types/data/test_raster.py b/libs/labelbox/tests/data/annotation_types/data/test_raster.py index 4ce787022..6bc8f2bbf 100644 --- a/libs/labelbox/tests/data/annotation_types/data/test_raster.py +++ b/libs/labelbox/tests/data/annotation_types/data/test_raster.py @@ -42,11 +42,13 @@ def test_ref(): uid = "uid" metadata = [] media_attributes = {} - data = ImageData(im_bytes=b'', - external_id=external_id, - uid=uid, - metadata=metadata, - media_attributes=media_attributes) + data = ImageData( + im_bytes=b"", + external_id=external_id, + uid=uid, + metadata=metadata, + media_attributes=media_attributes, + ) assert data.external_id == external_id assert data.uid == uid assert data.media_attributes == media_attributes diff --git a/libs/labelbox/tests/data/annotation_types/data/test_text.py b/libs/labelbox/tests/data/annotation_types/data/test_text.py index 0af0a37fb..865f93e65 100644 --- a/libs/labelbox/tests/data/annotation_types/data/test_text.py +++ b/libs/labelbox/tests/data/annotation_types/data/test_text.py @@ -15,9 +15,9 @@ def test_text(): text = "hello world" metadata = [] media_attributes = {} - text_data = TextData(text=text, - metadata=metadata, - media_attributes=media_attributes) + text_data = TextData( + text=text, metadata=metadata, media_attributes=media_attributes + ) assert text_data.text == text @@ -31,7 +31,7 @@ def test_url(): def test_file(tmpdir): content = "foo bar baz" file = "hello.txt" - dir = tmpdir.mkdir('data') + dir = tmpdir.mkdir("data") dir.join(file).write(content) text_data = TextData(file_path=os.path.join(dir.strpath, file)) assert len(text_data.value) == len(content) @@ -42,11 +42,13 @@ def test_ref(): uid = "uid" metadata = [] media_attributes = {} - data = TextData(text="hello world", - external_id=external_id, - uid=uid, - metadata=metadata, - media_attributes=media_attributes) + data = TextData( + text="hello world", + external_id=external_id, + uid=uid, + metadata=metadata, + media_attributes=media_attributes, + ) assert data.external_id == external_id assert data.uid == uid assert data.media_attributes == media_attributes diff --git a/libs/labelbox/tests/data/annotation_types/data/test_video.py 
b/libs/labelbox/tests/data/annotation_types/data/test_video.py index d0e5ed012..5fd77c2c8 100644 --- a/libs/labelbox/tests/data/annotation_types/data/test_video.py +++ b/libs/labelbox/tests/data/annotation_types/data/test_video.py @@ -22,7 +22,7 @@ def test_frames(): def test_file_path(): - path = 'tests/integration/media/cat.mp4' + path = "tests/integration/media/cat.mp4" raster_data = VideoData(file_path=path) with pytest.raises(ValueError): @@ -60,11 +60,13 @@ def test_ref(): } metadata = [] media_attributes = {} - data = VideoData(frames=data, - external_id=external_id, - uid=uid, - metadata=metadata, - media_attributes=media_attributes) + data = VideoData( + frames=data, + external_id=external_id, + uid=uid, + metadata=metadata, + media_attributes=media_attributes, + ) assert data.external_id == external_id assert data.uid == uid assert data.media_attributes == media_attributes diff --git a/libs/labelbox/tests/data/annotation_types/geometry/test_line.py b/libs/labelbox/tests/data/annotation_types/geometry/test_line.py index 10362e728..d6fd1108c 100644 --- a/libs/labelbox/tests/data/annotation_types/geometry/test_line.py +++ b/libs/labelbox/tests/data/annotation_types/geometry/test_line.py @@ -16,7 +16,7 @@ def test_line(): expected = {"coordinates": [points], "type": "MultiLineString"} line = Line(points=[Point(x=x, y=y) for x, y in points]) assert line.geometry == expected - expected['coordinates'] = tuple([tuple([tuple(x) for x in points])]) + expected["coordinates"] = tuple([tuple([tuple(x) for x in points])]) assert line.shapely.__geo_interface__ == expected raster = line.draw(height=32, width=32, thickness=1) diff --git a/libs/labelbox/tests/data/annotation_types/geometry/test_mask.py b/libs/labelbox/tests/data/annotation_types/geometry/test_mask.py index 960e64d9a..6fe8422cf 100644 --- a/libs/labelbox/tests/data/annotation_types/geometry/test_mask.py +++ b/libs/labelbox/tests/data/annotation_types/geometry/test_mask.py @@ -19,45 +19,114 @@ def test_mask(): mask1 = Mask(mask=mask_data, color=(255, 255, 255)) expected1 = { - 'type': - 'MultiPolygon', - 'coordinates': [ - (((0.0, 0.0), (0.0, 1.0), (0.0, 2.0), (0.0, 3.0), (0.0, 4.0), (0.0, - 5.0), - (0.0, 6.0), (0.0, 7.0), (0.0, 8.0), (0.0, 9.0), (0.0, 10.0), - (1.0, 10.0), (2.0, 10.0), (3.0, 10.0), (4.0, 10.0), (5.0, 10.0), - (6.0, 10.0), (7.0, 10.0), (8.0, 10.0), (9.0, 10.0), (10.0, 10.0), - (10.0, 9.0), (10.0, 8.0), (10.0, 7.0), (10.0, 6.0), (10.0, 5.0), - (10.0, 4.0), (10.0, 3.0), (10.0, 2.0), (10.0, 1.0), (10.0, 0.0), - (9.0, 0.0), (8.0, 0.0), (7.0, 0.0), (6.0, 0.0), (5.0, 0.0), - (4.0, 0.0), (3.0, 0.0), (2.0, 0.0), (1.0, 0.0), (0.0, 0.0)),) - ] + "type": "MultiPolygon", + "coordinates": [ + ( + ( + (0.0, 0.0), + (0.0, 1.0), + (0.0, 2.0), + (0.0, 3.0), + (0.0, 4.0), + (0.0, 5.0), + (0.0, 6.0), + (0.0, 7.0), + (0.0, 8.0), + (0.0, 9.0), + (0.0, 10.0), + (1.0, 10.0), + (2.0, 10.0), + (3.0, 10.0), + (4.0, 10.0), + (5.0, 10.0), + (6.0, 10.0), + (7.0, 10.0), + (8.0, 10.0), + (9.0, 10.0), + (10.0, 10.0), + (10.0, 9.0), + (10.0, 8.0), + (10.0, 7.0), + (10.0, 6.0), + (10.0, 5.0), + (10.0, 4.0), + (10.0, 3.0), + (10.0, 2.0), + (10.0, 1.0), + (10.0, 0.0), + (9.0, 0.0), + (8.0, 0.0), + (7.0, 0.0), + (6.0, 0.0), + (5.0, 0.0), + (4.0, 0.0), + (3.0, 0.0), + (2.0, 0.0), + (1.0, 0.0), + (0.0, 0.0), + ), + ) + ], } assert mask1.geometry == expected1 assert mask1.shapely.__geo_interface__ == expected1 mask2 = Mask(mask=mask_data, color=(0, 255, 255)) expected2 = { - 'type': - 'MultiPolygon', - 'coordinates': [ - (((20.0, 20.0), (20.0, 21.0), 
(20.0, 22.0), (20.0, 23.0), - (20.0, 24.0), (20.0, 25.0), (20.0, 26.0), (20.0, 27.0), - (20.0, 28.0), (20.0, 29.0), (20.0, 30.0), (21.0, 30.0), - (22.0, 30.0), (23.0, 30.0), (24.0, 30.0), (25.0, 30.0), - (26.0, 30.0), (27.0, 30.0), (28.0, 30.0), (29.0, 30.0), - (30.0, 30.0), (30.0, 29.0), (30.0, 28.0), (30.0, 27.0), - (30.0, 26.0), (30.0, 25.0), (30.0, 24.0), (30.0, 23.0), - (30.0, 22.0), (30.0, 21.0), (30.0, 20.0), (29.0, 20.0), - (28.0, 20.0), (27.0, 20.0), (26.0, 20.0), (25.0, 20.0), - (24.0, 20.0), (23.0, 20.0), (22.0, 20.0), (21.0, 20.0), (20.0, - 20.0)),) - ] + "type": "MultiPolygon", + "coordinates": [ + ( + ( + (20.0, 20.0), + (20.0, 21.0), + (20.0, 22.0), + (20.0, 23.0), + (20.0, 24.0), + (20.0, 25.0), + (20.0, 26.0), + (20.0, 27.0), + (20.0, 28.0), + (20.0, 29.0), + (20.0, 30.0), + (21.0, 30.0), + (22.0, 30.0), + (23.0, 30.0), + (24.0, 30.0), + (25.0, 30.0), + (26.0, 30.0), + (27.0, 30.0), + (28.0, 30.0), + (29.0, 30.0), + (30.0, 30.0), + (30.0, 29.0), + (30.0, 28.0), + (30.0, 27.0), + (30.0, 26.0), + (30.0, 25.0), + (30.0, 24.0), + (30.0, 23.0), + (30.0, 22.0), + (30.0, 21.0), + (30.0, 20.0), + (29.0, 20.0), + (28.0, 20.0), + (27.0, 20.0), + (26.0, 20.0), + (25.0, 20.0), + (24.0, 20.0), + (23.0, 20.0), + (22.0, 20.0), + (21.0, 20.0), + (20.0, 20.0), + ), + ) + ], } assert mask2.geometry == expected2 assert mask2.shapely.__geo_interface__ == expected2 - gt_mask = cv2.cvtColor(cv2.imread("tests/data/assets/mask.png"), - cv2.COLOR_BGR2RGB) + gt_mask = cv2.cvtColor( + cv2.imread("tests/data/assets/mask.png"), cv2.COLOR_BGR2RGB + ) assert (gt_mask == mask1.mask.arr).all() assert (gt_mask == mask2.mask.arr).all() @@ -66,13 +135,11 @@ def test_mask(): assert (raster1 != raster2).any() - gt1 = Rectangle(start=Point(x=0, y=0), - end=Point(x=10, y=10)).draw(height=raster1.shape[0], - width=raster1.shape[1], - color=(255, 255, 255)) - gt2 = Rectangle(start=Point(x=20, y=20), - end=Point(x=30, y=30)).draw(height=raster2.shape[0], - width=raster2.shape[1], - color=(0, 255, 255)) + gt1 = Rectangle(start=Point(x=0, y=0), end=Point(x=10, y=10)).draw( + height=raster1.shape[0], width=raster1.shape[1], color=(255, 255, 255) + ) + gt2 = Rectangle(start=Point(x=20, y=20), end=Point(x=30, y=30)).draw( + height=raster2.shape[0], width=raster2.shape[1], color=(0, 255, 255) + ) assert (raster1 == gt1).all() assert (raster2 == gt2).all() diff --git a/libs/labelbox/tests/data/annotation_types/geometry/test_point.py b/libs/labelbox/tests/data/annotation_types/geometry/test_point.py index bca3900d2..335fb6a3a 100644 --- a/libs/labelbox/tests/data/annotation_types/geometry/test_point.py +++ b/libs/labelbox/tests/data/annotation_types/geometry/test_point.py @@ -15,7 +15,7 @@ def test_point(): point = Point(x=0, y=1) expected = {"coordinates": [0, 1], "type": "Point"} assert point.geometry == expected - expected['coordinates'] = tuple(expected['coordinates']) + expected["coordinates"] = tuple(expected["coordinates"]) assert point.shapely.__geo_interface__ == expected raster = point.draw(height=32, width=32, thickness=1) diff --git a/libs/labelbox/tests/data/annotation_types/geometry/test_polygon.py b/libs/labelbox/tests/data/annotation_types/geometry/test_polygon.py index 084349023..0a0bb49b0 100644 --- a/libs/labelbox/tests/data/annotation_types/geometry/test_polygon.py +++ b/libs/labelbox/tests/data/annotation_types/geometry/test_polygon.py @@ -4,6 +4,7 @@ from labelbox.data.annotation_types import Polygon, Point from pydantic import ValidationError + def test_polygon(): with pytest.raises(ValidationError): 
polygon = Polygon() @@ -14,12 +15,13 @@ def test_polygon(): with pytest.raises(ValidationError): polygon = Polygon(points=[Point(x=0, y=1), Point(x=0, y=1)]) - points = [[0., 1.], [0., 2.], [2., 2.], [2., 0.]] + points = [[0.0, 1.0], [0.0, 2.0], [2.0, 2.0], [2.0, 0.0]] expected = {"coordinates": [points + [points[0]]], "type": "Polygon"} polygon = Polygon(points=[Point(x=x, y=y) for x, y in points]) assert polygon.geometry == expected - expected['coordinates'] = tuple( - [tuple([tuple(x) for x in points + [points[0]]])]) + expected["coordinates"] = tuple( + [tuple([tuple(x) for x in points + [points[0]]])] + ) assert polygon.shapely.__geo_interface__ == expected raster = polygon.draw(10, 10) diff --git a/libs/labelbox/tests/data/annotation_types/geometry/test_rectangle.py b/libs/labelbox/tests/data/annotation_types/geometry/test_rectangle.py index d1d7331d6..54f85eed8 100644 --- a/libs/labelbox/tests/data/annotation_types/geometry/test_rectangle.py +++ b/libs/labelbox/tests/data/annotation_types/geometry/test_rectangle.py @@ -13,12 +13,12 @@ def test_rectangle(): points = [[[0.0, 1.0], [0.0, 10.0], [10.0, 10.0], [10.0, 1.0], [0.0, 1.0]]] expected = {"coordinates": points, "type": "Polygon"} assert rectangle.geometry == expected - expected['coordinates'] = tuple([tuple([tuple(x) for x in points[0]])]) + expected["coordinates"] = tuple([tuple([tuple(x) for x in points[0]])]) assert rectangle.shapely.__geo_interface__ == expected raster = rectangle.draw(height=32, width=32) assert (cv2.imread("tests/data/assets/rectangle.png") == raster).all() - xyhw = Rectangle.from_xyhw(0., 0, 10, 10) - assert xyhw.start == Point(x=0, y=0.) + xyhw = Rectangle.from_xyhw(0.0, 0, 10, 10) + assert xyhw.start == Point(x=0, y=0.0) assert xyhw.end == Point(x=10, y=10.0) diff --git a/libs/labelbox/tests/data/annotation_types/test_annotation.py b/libs/labelbox/tests/data/annotation_types/test_annotation.py index 926d8bc97..8cdeac9ba 100644 --- a/libs/labelbox/tests/data/annotation_types/test_annotation.py +++ b/libs/labelbox/tests/data/annotation_types/test_annotation.py @@ -1,8 +1,13 @@ import pytest -from labelbox.data.annotation_types import (Text, Point, Line, - ClassificationAnnotation, - ObjectAnnotation, TextEntity) +from labelbox.data.annotation_types import ( + Text, + Point, + Line, + ClassificationAnnotation, + ObjectAnnotation, + TextEntity, +) from labelbox.data.annotation_types.video import VideoObjectAnnotation from labelbox.data.annotation_types.geometry.rectangle import Rectangle from labelbox.data.annotation_types.video import VideoClassificationAnnotation @@ -19,7 +24,11 @@ def test_annotation(): value=line, name=name, ) - assert annotation.value.points[0].model_dump() == {'extra': {}, 'x': 1., 'y': 2.} + assert annotation.value.points[0].model_dump() == { + "extra": {}, + "x": 1.0, + "y": 2.0, + } assert annotation.name == name # Check ner @@ -68,25 +77,27 @@ def test_video_annotations(): def test_confidence_for_video_is_not_supported(): with pytest.raises(ConfidenceNotSupportedException): - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=26.5), - end=Point(extra={}, - x=561.0, - y=348.0)), - classifications=[], - frame=24, - keyframe=False, - confidence=0.3434), + ( + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ 
+ "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=26.5), + end=Point(extra={}, x=561.0, y=348.0), + ), + classifications=[], + frame=24, + keyframe=False, + confidence=0.3434, + ), + ) def test_confidence_value_range_validation(): diff --git a/libs/labelbox/tests/data/annotation_types/test_collection.py b/libs/labelbox/tests/data/annotation_types/test_collection.py index 34b868162..16c9d699f 100644 --- a/libs/labelbox/tests/data/annotation_types/test_collection.py +++ b/libs/labelbox/tests/data/annotation_types/test_collection.py @@ -4,9 +4,16 @@ import numpy as np import pytest -from labelbox.data.annotation_types import (LabelGenerator, ObjectAnnotation, - ImageData, MaskData, Line, Mask, - Point, Label) +from labelbox.data.annotation_types import ( + LabelGenerator, + ObjectAnnotation, + ImageData, + MaskData, + Line, + Mask, + Point, + Label, +) from labelbox import OntologyBuilder, Tool @@ -17,7 +24,6 @@ def list_of_labels(): @pytest.fixture def signer(): - def get_signer(uuid): return lambda x: uuid @@ -25,7 +31,6 @@ def get_signer(uuid): class FakeDataset: - def __init__(self): self.uid = "ckrb4tgm51xl10ybc7lv9ghm7" self.exports = [] @@ -38,9 +43,12 @@ def create_data_row(self, row_data, external_id=None): def create_data_rows(self, args): for arg in args: self.exports.append( - SimpleNamespace(row_data=arg['row_data'], - external_id=arg['external_id'], - uid=self.uid)) + SimpleNamespace( + row_data=arg["row_data"], + external_id=arg["external_id"], + uid=self.uid, + ) + ) return self def wait_till_done(self): @@ -72,23 +80,26 @@ def test_adding_schema_ids(): data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), annotations=[ ObjectAnnotation( - value=Line( - points=[Point(x=1, y=2), Point(x=2, y=2)]), + value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]), name=name, ) - ]) + ], + ) feature_schema_id = "expected_id" - ontology = OntologyBuilder(tools=[ - Tool(Tool.Type.LINE, name=name, feature_schema_id=feature_schema_id) - ]) + ontology = OntologyBuilder( + tools=[ + Tool(Tool.Type.LINE, name=name, feature_schema_id=feature_schema_id) + ] + ) generator = LabelGenerator([label]).assign_feature_schema_ids(ontology) assert next(generator).annotations[0].feature_schema_id == feature_schema_id def test_adding_urls(signer): - label = Label(data=ImageData(arr=np.random.random((32, 32, - 3)).astype(np.uint8)), - annotations=[]) + label = Label( + data=ImageData(arr=np.random.random((32, 32, 3)).astype(np.uint8)), + annotations=[], + ) uuid = str(uuid4()) generator = LabelGenerator([label]).add_url_to_data(signer(uuid)) assert label.data.url != uuid @@ -98,9 +109,10 @@ def test_adding_urls(signer): def test_adding_to_dataset(signer): dataset = FakeDataset() - label = Label(data=ImageData(arr=np.random.random((32, 32, - 3)).astype(np.uint8)), - annotations=[]) + label = Label( + data=ImageData(arr=np.random.random((32, 32, 3)).astype(np.uint8)), + annotations=[], + ) uuid = str(uuid4()) generator = LabelGenerator([label]).add_to_dataset(dataset, signer(uuid)) assert label.data.url != uuid @@ -115,12 +127,17 @@ def test_adding_to_masks(signer): label = Label( data=ImageData(arr=np.random.random((32, 32, 3)).astype(np.uint8)), annotations=[ - ObjectAnnotation(name="1234", - value=Mask(mask=MaskData( - arr=np.random.random((32, 32, - 3)).astype(np.uint8)), - color=[255, 255, 255])) - ]) + ObjectAnnotation( + name="1234", + value=Mask( + mask=MaskData( + 
arr=np.random.random((32, 32, 3)).astype(np.uint8) + ), + color=[255, 255, 255], + ), + ) + ], + ) uuid = str(uuid4()) generator = LabelGenerator([label]).add_url_to_masks(signer(uuid)) assert label.annotations[0].value.mask.url != uuid diff --git a/libs/labelbox/tests/data/annotation_types/test_label.py b/libs/labelbox/tests/data/annotation_types/test_label.py index f0957fcee..5bdfb6bde 100644 --- a/libs/labelbox/tests/data/annotation_types/test_label.py +++ b/libs/labelbox/tests/data/annotation_types/test_label.py @@ -2,12 +2,24 @@ import numpy as np import labelbox.types as lb_types -from labelbox import OntologyBuilder, Tool, Classification as OClassification, Option -from labelbox.data.annotation_types import (ClassificationAnswer, Radio, Text, - ClassificationAnnotation, - PromptText, - ObjectAnnotation, Point, Line, - ImageData, Label) +from labelbox import ( + OntologyBuilder, + Tool, + Classification as OClassification, + Option, +) +from labelbox.data.annotation_types import ( + ClassificationAnswer, + Radio, + Text, + ClassificationAnnotation, + PromptText, + ObjectAnnotation, + Point, + Line, + ImageData, + Label, +) import pytest @@ -17,15 +29,17 @@ def test_schema_assignment_geometry(): data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), annotations=[ ObjectAnnotation( - value=Line( - points=[Point(x=1, y=2), Point(x=2, y=2)]), + value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]), name=name, ) - ]) + ], + ) feature_schema_id = "expected_id" - ontology = OntologyBuilder(tools=[ - Tool(Tool.Type.LINE, name=name, feature_schema_id=feature_schema_id) - ]) + ontology = OntologyBuilder( + tools=[ + Tool(Tool.Type.LINE, name=name, feature_schema_id=feature_schema_id) + ] + ) label.assign_feature_schema_ids(ontology) assert label.annotations[0].feature_schema_id == feature_schema_id @@ -36,38 +50,47 @@ def test_schema_assignment_classification(): text_name = "text_name" option_name = "my_option" - label = Label(data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ClassificationAnnotation(value=Radio( - answer=ClassificationAnswer(name=option_name)), - name=radio_name), - ClassificationAnnotation(value=Text(answer="some text"), - name=text_name) - ]) + label = Label( + data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), + annotations=[ + ClassificationAnnotation( + value=Radio(answer=ClassificationAnswer(name=option_name)), + name=radio_name, + ), + ClassificationAnnotation( + value=Text(answer="some text"), name=text_name + ), + ], + ) radio_schema_id = "radio_schema_id" text_schema_id = "text_schema_id" option_schema_id = "option_schema_id" ontology = OntologyBuilder( tools=[], classifications=[ - OClassification(class_type=OClassification.Type.RADIO, - name=radio_name, - feature_schema_id=radio_schema_id, - options=[ - Option(value=option_name, - feature_schema_id=option_schema_id) - ]), + OClassification( + class_type=OClassification.Type.RADIO, + name=radio_name, + feature_schema_id=radio_schema_id, + options=[ + Option( + value=option_name, feature_schema_id=option_schema_id + ) + ], + ), OClassification( class_type=OClassification.Type.TEXT, name=text_name, feature_schema_id=text_schema_id, - ) - ]) + ), + ], + ) label.assign_feature_schema_ids(ontology) assert label.annotations[0].feature_schema_id == radio_schema_id assert label.annotations[1].feature_schema_id == text_schema_id - assert label.annotations[ - 0].value.answer.feature_schema_id == option_schema_id + assert ( + label.annotations[0].value.answer.feature_schema_id == 
option_schema_id + ) def test_schema_assignment_subclass(): @@ -81,34 +104,48 @@ def test_schema_assignment_subclass(): label = Label( data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), annotations=[ - ObjectAnnotation(value=Line( - points=[Point(x=1, y=2), Point(x=2, y=2)]), - name=name, - classifications=[classification]) - ]) + ObjectAnnotation( + value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]), + name=name, + classifications=[classification], + ) + ], + ) feature_schema_id = "expected_id" classification_schema_id = "classification_id" option_schema_id = "option_schema_id" - ontology = OntologyBuilder(tools=[ - Tool(Tool.Type.LINE, - name=name, - feature_schema_id=feature_schema_id, - classifications=[ - OClassification(class_type=OClassification.Type.RADIO, - name=radio_name, - feature_schema_id=classification_schema_id, - options=[ - Option(value=option_name, - feature_schema_id=option_schema_id) - ]) - ]) - ]) + ontology = OntologyBuilder( + tools=[ + Tool( + Tool.Type.LINE, + name=name, + feature_schema_id=feature_schema_id, + classifications=[ + OClassification( + class_type=OClassification.Type.RADIO, + name=radio_name, + feature_schema_id=classification_schema_id, + options=[ + Option( + value=option_name, + feature_schema_id=option_schema_id, + ) + ], + ) + ], + ) + ] + ) label.assign_feature_schema_ids(ontology) assert label.annotations[0].feature_schema_id == feature_schema_id - assert label.annotations[0].classifications[ - 0].feature_schema_id == classification_schema_id - assert label.annotations[0].classifications[ - 0].value.answer.feature_schema_id == option_schema_id + assert ( + label.annotations[0].classifications[0].feature_schema_id + == classification_schema_id + ) + assert ( + label.annotations[0].classifications[0].value.answer.feature_schema_id + == option_schema_id + ) def test_highly_nested(): @@ -121,92 +158,117 @@ def test_highly_nested(): name=radio_name, value=Radio(answer=ClassificationAnswer(name=option_name)), classifications=[ - ClassificationAnnotation(value=Radio(answer=ClassificationAnswer( - name=nested_option_name)), - name=nested_name) - ]) + ClassificationAnnotation( + value=Radio( + answer=ClassificationAnswer(name=nested_option_name) + ), + name=nested_name, + ) + ], + ) label = Label( data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), annotations=[ - ObjectAnnotation(value=Line( - points=[Point(x=1, y=2), Point(x=2, y=2)]), - name=name, - classifications=[classification]) - ]) + ObjectAnnotation( + value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]), + name=name, + classifications=[classification], + ) + ], + ) feature_schema_id = "expected_id" classification_schema_id = "classification_id" nested_classification_schema_id = "nested_classification_schema_id" option_schema_id = "option_schema_id" - ontology = OntologyBuilder(tools=[ - Tool(Tool.Type.LINE, - name=name, - feature_schema_id=feature_schema_id, - classifications=[ - OClassification( - class_type=OClassification.Type.RADIO, - name=radio_name, - feature_schema_id=classification_schema_id, - options=[ - Option(value=option_name, + ontology = OntologyBuilder( + tools=[ + Tool( + Tool.Type.LINE, + name=name, + feature_schema_id=feature_schema_id, + classifications=[ + OClassification( + class_type=OClassification.Type.RADIO, + name=radio_name, + feature_schema_id=classification_schema_id, + options=[ + Option( + value=option_name, feature_schema_id=option_schema_id, options=[ OClassification( class_type=OClassification.Type.RADIO, name=nested_name, - 
feature_schema_id= - nested_classification_schema_id, + feature_schema_id=nested_classification_schema_id, options=[ Option( value=nested_option_name, - feature_schema_id= - nested_classification_schema_id) - ]) - ]) - ]) - ]) - ]) + feature_schema_id=nested_classification_schema_id, + ) + ], + ) + ], + ) + ], + ) + ], + ) + ] + ) label.assign_feature_schema_ids(ontology) assert label.annotations[0].feature_schema_id == feature_schema_id - assert label.annotations[0].classifications[ - 0].feature_schema_id == classification_schema_id - assert label.annotations[0].classifications[ - 0].value.answer.feature_schema_id == option_schema_id + assert ( + label.annotations[0].classifications[0].feature_schema_id + == classification_schema_id + ) + assert ( + label.annotations[0].classifications[0].value.answer.feature_schema_id + == option_schema_id + ) def test_schema_assignment_confidence(): name = "line_feature" - label = Label(data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ObjectAnnotation(value=Line( - points=[Point(x=1, y=2), - Point(x=2, y=2)],), - name=name, - confidence=0.914) - ]) + label = Label( + data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), + annotations=[ + ObjectAnnotation( + value=Line( + points=[Point(x=1, y=2), Point(x=2, y=2)], + ), + name=name, + confidence=0.914, + ) + ], + ) assert label.annotations[0].confidence == 0.914 def test_initialize_label_no_coercion(): - global_key = 'global-key' + global_key = "global-key" ner_annotation = lb_types.ObjectAnnotation( name="ner", - value=lb_types.ConversationEntity(start=0, end=8, message_id="4")) - label = Label(data=lb_types.ConversationData(global_key=global_key), - annotations=[ner_annotation]) + value=lb_types.ConversationEntity(start=0, end=8, message_id="4"), + ) + label = Label( + data=lb_types.ConversationData(global_key=global_key), + annotations=[ner_annotation], + ) assert isinstance(label.data, lb_types.ConversationData) assert label.data.global_key == global_key + def test_prompt_classification_validation(): - global_key = 'global-key' + global_key = "global-key" prompt_text = lb_types.PromptClassificationAnnotation( - name="prompt text", - value=PromptText(answer="test") + name="prompt text", value=PromptText(answer="test") ) prompt_text_2 = lb_types.PromptClassificationAnnotation( - name="prompt text", - value=PromptText(answer="test") + name="prompt text", value=PromptText(answer="test") ) with pytest.raises(TypeError) as e_info: - label = Label(data={"global_key": global_key}, - annotations=[prompt_text, prompt_text_2]) + label = Label( + data={"global_key": global_key}, + annotations=[prompt_text, prompt_text_2], + ) diff --git a/libs/labelbox/tests/data/annotation_types/test_metrics.py b/libs/labelbox/tests/data/annotation_types/test_metrics.py index d2e488109..94c9521a5 100644 --- a/libs/labelbox/tests/data/annotation_types/test_metrics.py +++ b/libs/labelbox/tests/data/annotation_types/test_metrics.py @@ -1,7 +1,13 @@ import pytest -from labelbox.data.annotation_types.metrics import ConfusionMatrixAggregation, ScalarMetricAggregation -from labelbox.data.annotation_types.metrics import ConfusionMatrixMetric, ScalarMetric +from labelbox.data.annotation_types.metrics import ( + ConfusionMatrixAggregation, + ScalarMetricAggregation, +) +from labelbox.data.annotation_types.metrics import ( + ConfusionMatrixMetric, + ScalarMetric, +) from labelbox.data.annotation_types import ScalarMetric, Label, ImageData from labelbox.data.annotation_types.metrics.scalar import 
RESERVED_METRIC_NAMES from pydantic import ValidationError @@ -12,19 +18,22 @@ def test_legacy_scalar_metric(): metric = ScalarMetric(value=value) assert metric.value == value - label = Label(data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), - annotations=[metric]) + label = Label( + data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), annotations=[metric] + ) expected = { - 'data': { - 'uid': 'ckrmd9q8g000009mg6vej7hzg', + "data": { + "uid": "ckrmd9q8g000009mg6vej7hzg", }, - 'annotations': [{ - 'aggregation': ScalarMetricAggregation.ARITHMETIC_MEAN, - 'value': 10.0, - 'extra': {}, - }], - 'extra': {}, - 'is_benchmark_reference': False + "annotations": [ + { + "aggregation": ScalarMetricAggregation.ARITHMETIC_MEAN, + "value": 10.0, + "extra": {}, + } + ], + "extra": {}, + "is_benchmark_reference": False, } assert label.model_dump(exclude_none=True) == expected @@ -32,100 +41,118 @@ def test_legacy_scalar_metric(): # TODO: Test with confidence -@pytest.mark.parametrize('feature_name,subclass_name,aggregation,value', [ - ("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), - ("cat", None, ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), - (None, None, ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), - (None, None, None, 0.5), - ("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), - ("cat", None, ScalarMetricAggregation.HARMONIC_MEAN, 0.5), - (None, None, ScalarMetricAggregation.GEOMETRIC_MEAN, 0.5), - (None, None, ScalarMetricAggregation.SUM, 0.5), - ("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, { - 0.1: 0.2, - 0.3: 0.5, - 0.4: 0.8 - }), -]) +@pytest.mark.parametrize( + "feature_name,subclass_name,aggregation,value", + [ + ("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), + ("cat", None, ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), + (None, None, ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), + (None, None, None, 0.5), + ("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), + ("cat", None, ScalarMetricAggregation.HARMONIC_MEAN, 0.5), + (None, None, ScalarMetricAggregation.GEOMETRIC_MEAN, 0.5), + (None, None, ScalarMetricAggregation.SUM, 0.5), + ( + "cat", + "orange", + ScalarMetricAggregation.ARITHMETIC_MEAN, + {0.1: 0.2, 0.3: 0.5, 0.4: 0.8}, + ), + ], +) def test_custom_scalar_metric(feature_name, subclass_name, aggregation, value): - kwargs = {'aggregation': aggregation} if aggregation is not None else {} - metric = ScalarMetric(metric_name="custom_iou", - value=value, - feature_name=feature_name, - subclass_name=subclass_name, - **kwargs) + kwargs = {"aggregation": aggregation} if aggregation is not None else {} + metric = ScalarMetric( + metric_name="custom_iou", + value=value, + feature_name=feature_name, + subclass_name=subclass_name, + **kwargs, + ) assert metric.value == value - label = Label(data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), - annotations=[metric]) + label = Label( + data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), annotations=[metric] + ) expected = { - 'data': { - 'uid': 'ckrmd9q8g000009mg6vej7hzg', + "data": { + "uid": "ckrmd9q8g000009mg6vej7hzg", }, - 'annotations': [{ - 'value': - value, - 'metric_name': - 'custom_iou', - **({ - 'feature_name': feature_name - } if feature_name else {}), - **({ - 'subclass_name': subclass_name - } if subclass_name else {}), 'aggregation': - aggregation or ScalarMetricAggregation.ARITHMETIC_MEAN, - 'extra': {} - }], - 'extra': {}, - 'is_benchmark_reference': False + "annotations": [ + { + "value": value, + "metric_name": "custom_iou", + **({"feature_name": feature_name} if 
feature_name else {}), + **({"subclass_name": subclass_name} if subclass_name else {}), + "aggregation": aggregation + or ScalarMetricAggregation.ARITHMETIC_MEAN, + "extra": {}, + } + ], + "extra": {}, + "is_benchmark_reference": False, } assert label.model_dump(exclude_none=True) == expected -@pytest.mark.parametrize('feature_name,subclass_name,aggregation,value', [ - ("cat", "orange", ConfusionMatrixAggregation.CONFUSION_MATRIX, - (0, 1, 2, 3)), - ("cat", None, ConfusionMatrixAggregation.CONFUSION_MATRIX, (0, 1, 2, 3)), - (None, None, ConfusionMatrixAggregation.CONFUSION_MATRIX, (0, 1, 2, 3)), - (None, None, None, (0, 1, 2, 3)), - ("cat", "orange", ConfusionMatrixAggregation.CONFUSION_MATRIX, { - 0.1: (0, 1, 2, 3), - 0.3: (0, 1, 2, 3), - 0.4: (0, 1, 2, 3) - }), -]) -def test_custom_confusison_matrix_metric(feature_name, subclass_name, - aggregation, value): - kwargs = {'aggregation': aggregation} if aggregation is not None else {} - metric = ConfusionMatrixMetric(metric_name="confusion_matrix_50_pct_iou", - value=value, - feature_name=feature_name, - subclass_name=subclass_name, - **kwargs) +@pytest.mark.parametrize( + "feature_name,subclass_name,aggregation,value", + [ + ( + "cat", + "orange", + ConfusionMatrixAggregation.CONFUSION_MATRIX, + (0, 1, 2, 3), + ), + ( + "cat", + None, + ConfusionMatrixAggregation.CONFUSION_MATRIX, + (0, 1, 2, 3), + ), + (None, None, ConfusionMatrixAggregation.CONFUSION_MATRIX, (0, 1, 2, 3)), + (None, None, None, (0, 1, 2, 3)), + ( + "cat", + "orange", + ConfusionMatrixAggregation.CONFUSION_MATRIX, + {0.1: (0, 1, 2, 3), 0.3: (0, 1, 2, 3), 0.4: (0, 1, 2, 3)}, + ), + ], +) +def test_custom_confusison_matrix_metric( + feature_name, subclass_name, aggregation, value +): + kwargs = {"aggregation": aggregation} if aggregation is not None else {} + metric = ConfusionMatrixMetric( + metric_name="confusion_matrix_50_pct_iou", + value=value, + feature_name=feature_name, + subclass_name=subclass_name, + **kwargs, + ) assert metric.value == value - label = Label(data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), - annotations=[metric]) + label = Label( + data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), annotations=[metric] + ) expected = { - 'data': { - 'uid': 'ckrmd9q8g000009mg6vej7hzg', + "data": { + "uid": "ckrmd9q8g000009mg6vej7hzg", }, - 'annotations': [{ - 'value': - value, - 'metric_name': - 'confusion_matrix_50_pct_iou', - **({ - 'feature_name': feature_name - } if feature_name else {}), - **({ - 'subclass_name': subclass_name - } if subclass_name else {}), 'aggregation': - aggregation or ConfusionMatrixAggregation.CONFUSION_MATRIX, - 'extra': {} - }], - 'extra': {}, - 'is_benchmark_reference': False + "annotations": [ + { + "value": value, + "metric_name": "confusion_matrix_50_pct_iou", + **({"feature_name": feature_name} if feature_name else {}), + **({"subclass_name": subclass_name} if subclass_name else {}), + "aggregation": aggregation + or ConfusionMatrixAggregation.CONFUSION_MATRIX, + "extra": {}, + } + ], + "extra": {}, + "is_benchmark_reference": False, } assert label.model_dump(exclude_none=True) == expected @@ -141,11 +168,14 @@ def test_invalid_aggregations(): metric = ScalarMetric( metric_name="invalid aggregation", value=0.1, - aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX) + aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX, + ) with pytest.raises(ValidationError) as exc_info: - metric = ConfusionMatrixMetric(metric_name="invalid aggregation", - value=[0, 1, 2, 3], - aggregation=ScalarMetricAggregation.SUM) + metric = 
ConfusionMatrixMetric( + metric_name="invalid aggregation", + value=[0, 1, 2, 3], + aggregation=ScalarMetricAggregation.SUM, + ) def test_invalid_number_of_confidence_scores(): @@ -153,17 +183,21 @@ def test_invalid_number_of_confidence_scores(): metric = ScalarMetric(metric_name="too few scores", value={0.1: 0.1}) assert "Number of confidence scores must be greater" in str(exc_info.value) with pytest.raises(ValidationError) as exc_info: - metric = ConfusionMatrixMetric(metric_name="too few scores", - value={0.1: [0, 1, 2, 3]}) + metric = ConfusionMatrixMetric( + metric_name="too few scores", value={0.1: [0, 1, 2, 3]} + ) assert "Number of confidence scores must be greater" in str(exc_info.value) with pytest.raises(ValidationError) as exc_info: - metric = ScalarMetric(metric_name="too many scores", - value={i / 20.: 0.1 for i in range(20)}) + metric = ScalarMetric( + metric_name="too many scores", + value={i / 20.0: 0.1 for i in range(20)}, + ) assert "Number of confidence scores must be greater" in str(exc_info.value) with pytest.raises(ValidationError) as exc_info: metric = ConfusionMatrixMetric( metric_name="too many scores", - value={i / 20.: [0, 1, 2, 3] for i in range(20)}) + value={i / 20.0: [0, 1, 2, 3] for i in range(20)}, + ) assert "Number of confidence scores must be greater" in str(exc_info.value) @@ -171,4 +205,4 @@ def test_invalid_number_of_confidence_scores(): def test_reserved_names(metric_name: str): with pytest.raises(ValidationError) as exc_info: ScalarMetric(metric_name=metric_name, value=0.5) - assert 'is a reserved metric name' in exc_info.value.errors()[0]['msg'] + assert "is a reserved metric name" in exc_info.value.errors()[0]["msg"] diff --git a/libs/labelbox/tests/data/annotation_types/test_ner.py b/libs/labelbox/tests/data/annotation_types/test_ner.py index 9619689b1..32f40472e 100644 --- a/libs/labelbox/tests/data/annotation_types/test_ner.py +++ b/libs/labelbox/tests/data/annotation_types/test_ner.py @@ -1,5 +1,11 @@ -from labelbox.data.annotation_types import TextEntity, DocumentEntity, DocumentTextSelection -from labelbox.data.annotation_types.ner.conversation_entity import ConversationEntity +from labelbox.data.annotation_types import ( + TextEntity, + DocumentEntity, + DocumentTextSelection, +) +from labelbox.data.annotation_types.ner.conversation_entity import ( + ConversationEntity, +) def test_ner(): @@ -11,9 +17,11 @@ def test_ner(): def test_document_entity(): - document_entity = DocumentEntity(text_selections=[ - DocumentTextSelection(token_ids=["1", "2"], group_id="1", page=1) - ]) + document_entity = DocumentEntity( + text_selections=[ + DocumentTextSelection(token_ids=["1", "2"], group_id="1", page=1) + ] + ) assert document_entity.text_selections[0].token_ids == ["1", "2"] assert document_entity.text_selections[0].group_id == "1" diff --git a/libs/labelbox/tests/data/annotation_types/test_tiled_image.py b/libs/labelbox/tests/data/annotation_types/test_tiled_image.py index aea6587f6..46f2383d6 100644 --- a/libs/labelbox/tests/data/annotation_types/test_tiled_image.py +++ b/libs/labelbox/tests/data/annotation_types/test_tiled_image.py @@ -3,10 +3,13 @@ from labelbox.data.annotation_types.geometry.point import Point from labelbox.data.annotation_types.geometry.line import Line from labelbox.data.annotation_types.geometry.rectangle import Rectangle -from labelbox.data.annotation_types.data.tiled_image import (EPSG, TiledBounds, - TileLayer, - TiledImageData, - EPSGTransformer) +from labelbox.data.annotation_types.data.tiled_image import ( + EPSG, + 
TiledBounds, + TileLayer, + TiledImageData, + EPSGTransformer, +) from pydantic import ValidationError @@ -29,21 +32,26 @@ def test_tiled_bounds(epsg): def test_tiled_bounds_same(epsg): single_bound = Point(x=0, y=0) with pytest.raises(ValidationError): - tiled_bounds = TiledBounds(epsg=epsg, - bounds=[single_bound, single_bound]) + tiled_bounds = TiledBounds( + epsg=epsg, bounds=[single_bound, single_bound] + ) def test_create_tiled_image_data(): bounds_points = [Point(x=0, y=0), Point(x=5, y=5)] - url = "https://labelbox.s3-us-west-2.amazonaws.com/pathology/{z}/{x}/{y}.png" + url = ( + "https://labelbox.s3-us-west-2.amazonaws.com/pathology/{z}/{x}/{y}.png" + ) zoom_levels = (1, 10) tile_layer = TileLayer(url=url, name="slippy map tile") tile_bounds = TiledBounds(epsg=EPSG.EPSG4326, bounds=bounds_points) - tiled_image_data = TiledImageData(tile_layer=tile_layer, - tile_bounds=tile_bounds, - zoom_levels=zoom_levels, - version=2) + tiled_image_data = TiledImageData( + tile_layer=tile_layer, + tile_bounds=tile_bounds, + zoom_levels=zoom_levels, + version=2, + ) assert isinstance(tiled_image_data, TiledImageData) assert tiled_image_data.tile_bounds.bounds == bounds_points assert tiled_image_data.tile_layer.url == url @@ -53,20 +61,24 @@ def test_create_tiled_image_data(): def test_epsg_point_projections(): zoom = 4 - bounds_simple = TiledBounds(epsg=EPSG.SIMPLEPIXEL, - bounds=[Point(x=0, y=0), - Point(x=256, y=256)]) - - bounds_3857 = TiledBounds(epsg=EPSG.EPSG3857, - bounds=[ - Point(x=-104.150390625, y=30.789036751261136), - Point(x=-81.8701171875, y=45.920587344733654) - ]) - bounds_4326 = TiledBounds(epsg=EPSG.EPSG4326, - bounds=[ - Point(x=-104.150390625, y=30.789036751261136), - Point(x=-81.8701171875, y=45.920587344733654) - ]) + bounds_simple = TiledBounds( + epsg=EPSG.SIMPLEPIXEL, bounds=[Point(x=0, y=0), Point(x=256, y=256)] + ) + + bounds_3857 = TiledBounds( + epsg=EPSG.EPSG3857, + bounds=[ + Point(x=-104.150390625, y=30.789036751261136), + Point(x=-81.8701171875, y=45.920587344733654), + ], + ) + bounds_4326 = TiledBounds( + epsg=EPSG.EPSG4326, + bounds=[ + Point(x=-104.150390625, y=30.789036751261136), + Point(x=-81.8701171875, y=45.920587344733654), + ], + ) point = Point(x=-11016716.012685884, y=5312679.21393289) point_two = Point(x=-12016716.012685884, y=5212679.21393289) @@ -82,7 +94,8 @@ def test_epsg_point_projections(): src_epsg=EPSG.EPSG3857, pixel_bounds=bounds_simple, geo_bounds=bounds_3857, - zoom=zoom) + zoom=zoom, + ) transformer_3857_4326 = EPSGTransformer.create_geo_to_geo_transformer( src_epsg=EPSG.EPSG3857, tgt_epsg=EPSG.EPSG4326, @@ -91,7 +104,8 @@ def test_epsg_point_projections(): src_epsg=EPSG.EPSG4326, pixel_bounds=bounds_simple, geo_bounds=bounds_4326, - zoom=zoom) + zoom=zoom, + ) for shape in shapes_to_test: shape_simple = transformer_3857_simple(shape=shape) diff --git a/libs/labelbox/tests/data/annotation_types/test_video.py b/libs/labelbox/tests/data/annotation_types/test_video.py index f61dc7ec7..4b92e161d 100644 --- a/libs/labelbox/tests/data/annotation_types/test_video.py +++ b/libs/labelbox/tests/data/annotation_types/test_video.py @@ -2,18 +2,19 @@ def test_mask_frame(): - mask_frame = lb_types.MaskFrame(index=1, - instance_uri="http://path/to/frame.png") + mask_frame = lb_types.MaskFrame( + index=1, instance_uri="http://path/to/frame.png" + ) assert mask_frame.model_dump(by_alias=True) == { - 'index': 1, - 'imBytes': None, - 'instanceURI': 'http://path/to/frame.png' + "index": 1, + "imBytes": None, + "instanceURI": "http://path/to/frame.png", } 
def test_mask_instance(): mask_instance = lb_types.MaskInstance(color_rgb=(0, 0, 255), name="mask1") assert mask_instance.model_dump(by_alias=True, exclude_none=True) == { - 'colorRGB': (0, 0, 255), - 'name': 'mask1' + "colorRGB": (0, 0, 255), + "name": "mask1", } diff --git a/libs/labelbox/tests/data/conftest.py b/libs/labelbox/tests/data/conftest.py index 07f3460b8..aa1379407 100644 --- a/libs/labelbox/tests/data/conftest.py +++ b/libs/labelbox/tests/data/conftest.py @@ -1,6 +1,11 @@ import pytest -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnnotation, ClassificationAnswer, Radio +from labelbox.data.annotation_types.classification.classification import ( + Checklist, + ClassificationAnnotation, + ClassificationAnswer, + Radio, +) from labelbox.data.annotation_types.geometry.point import Point from labelbox.data.annotation_types.geometry.rectangle import Rectangle from labelbox.data.annotation_types.video import VideoObjectAnnotation @@ -20,21 +25,30 @@ def bbox_video_annotation_objects(): ), classifications=[ ClassificationAnnotation( - name='nested', - value=Radio(answer=ClassificationAnswer( - name='radio_option_1', - classifications=[ - ClassificationAnnotation( - name='nested_checkbox', - value=Checklist(answer=[ - ClassificationAnswer( - name='nested_checkbox_option_1'), - ClassificationAnswer( - name='nested_checkbox_option_2') - ])) - ])), + name="nested", + value=Radio( + answer=ClassificationAnswer( + name="radio_option_1", + classifications=[ + ClassificationAnnotation( + name="nested_checkbox", + value=Checklist( + answer=[ + ClassificationAnswer( + name="nested_checkbox_option_1" + ), + ClassificationAnswer( + name="nested_checkbox_option_2" + ), + ] + ), + ) + ], + ) + ), ) - ]), + ], + ), VideoObjectAnnotation( name="bbox", keyframe=True, @@ -43,7 +57,8 @@ def bbox_video_annotation_objects(): value=Rectangle( start=Point(x=186.0, y=98.0), # Top left end=Point(x=490.0, y=341.0), # Bottom right - )) + ), + ), ] return bbox_annotation diff --git a/libs/labelbox/tests/data/export/conftest.py b/libs/labelbox/tests/data/export/conftest.py index 104ee41dc..0836c2b9e 100644 --- a/libs/labelbox/tests/data/export/conftest.py +++ b/libs/labelbox/tests/data/export/conftest.py @@ -10,225 +10,196 @@ @pytest.fixture def ontology(): bbox_tool_with_nested_text = { - 'required': - False, - 'name': - 'bbox_tool_with_nested_text', - 'tool': - 'rectangle', - 'color': - '#a23030', - 'classifications': [{ - 'required': - False, - 'instructions': - 'nested', - 'name': - 'nested', - 'type': - 'radio', - 'options': [{ - 'label': - 'radio_option_1', - 'value': - 'radio_value_1', - 'options': [{ - 'required': - False, - 'instructions': - 'nested_checkbox', - 'name': - 'nested_checkbox', - 'type': - 'checklist', - 'options': [{ - 'label': 'nested_checkbox_option_1', - 'value': 'nested_checkbox_value_1', - 'options': [] - }, { - 'label': 'nested_checkbox_option_2', - 'value': 'nested_checkbox_value_2' - }] - }, { - 'required': False, - 'instructions': 'nested_text', - 'name': 'nested_text', - 'type': 'text', - 'options': [] - }] - },] - }] + "required": False, + "name": "bbox_tool_with_nested_text", + "tool": "rectangle", + "color": "#a23030", + "classifications": [ + { + "required": False, + "instructions": "nested", + "name": "nested", + "type": "radio", + "options": [ + { + "label": "radio_option_1", + "value": "radio_value_1", + "options": [ + { + "required": False, + "instructions": "nested_checkbox", + "name": "nested_checkbox", + "type": 
"checklist", + "options": [ + { + "label": "nested_checkbox_option_1", + "value": "nested_checkbox_value_1", + "options": [], + }, + { + "label": "nested_checkbox_option_2", + "value": "nested_checkbox_value_2", + }, + ], + }, + { + "required": False, + "instructions": "nested_text", + "name": "nested_text", + "type": "text", + "options": [], + }, + ], + }, + ], + } + ], } bbox_tool = { - 'required': - False, - 'name': - 'bbox', - 'tool': - 'rectangle', - 'color': - '#a23030', - 'classifications': [{ - 'required': - False, - 'instructions': - 'nested', - 'name': - 'nested', - 'type': - 'radio', - 'options': [{ - 'label': - 'radio_option_1', - 'value': - 'radio_value_1', - 'options': [{ - 'required': - False, - 'instructions': - 'nested_checkbox', - 'name': - 'nested_checkbox', - 'type': - 'checklist', - 'options': [{ - 'label': 'nested_checkbox_option_1', - 'value': 'nested_checkbox_value_1', - 'options': [] - }, { - 'label': 'nested_checkbox_option_2', - 'value': 'nested_checkbox_value_2' - }] - }] - },] - }] + "required": False, + "name": "bbox", + "tool": "rectangle", + "color": "#a23030", + "classifications": [ + { + "required": False, + "instructions": "nested", + "name": "nested", + "type": "radio", + "options": [ + { + "label": "radio_option_1", + "value": "radio_value_1", + "options": [ + { + "required": False, + "instructions": "nested_checkbox", + "name": "nested_checkbox", + "type": "checklist", + "options": [ + { + "label": "nested_checkbox_option_1", + "value": "nested_checkbox_value_1", + "options": [], + }, + { + "label": "nested_checkbox_option_2", + "value": "nested_checkbox_value_2", + }, + ], + } + ], + }, + ], + } + ], } polygon_tool = { - 'required': False, - 'name': 'polygon', - 'tool': 'polygon', - 'color': '#FF34FF', - 'classifications': [] + "required": False, + "name": "polygon", + "tool": "polygon", + "color": "#FF34FF", + "classifications": [], } polyline_tool = { - 'required': False, - 'name': 'polyline', - 'tool': 'line', - 'color': '#FF4A46', - 'classifications': [] + "required": False, + "name": "polyline", + "tool": "line", + "color": "#FF4A46", + "classifications": [], } point_tool = { - 'required': False, - 'name': 'point--', - 'tool': 'point', - 'color': '#008941', - 'classifications': [] + "required": False, + "name": "point--", + "tool": "point", + "color": "#008941", + "classifications": [], } entity_tool = { - 'required': False, - 'name': 'entity--', - 'tool': 'named-entity', - 'color': '#006FA6', - 'classifications': [] + "required": False, + "name": "entity--", + "tool": "named-entity", + "color": "#006FA6", + "classifications": [], } segmentation_tool = { - 'required': False, - 'name': 'segmentation--', - 'tool': 'superpixel', - 'color': '#A30059', - 'classifications': [] + "required": False, + "name": "segmentation--", + "tool": "superpixel", + "color": "#A30059", + "classifications": [], } raster_segmentation_tool = { - 'required': False, - 'name': 'segmentation_mask', - 'tool': 'raster-segmentation', - 'color': '#ff0000', - 'classifications': [] + "required": False, + "name": "segmentation_mask", + "tool": "raster-segmentation", + "color": "#ff0000", + "classifications": [], } checklist = { - 'required': - False, - 'instructions': - 'checklist', - 'name': - 'checklist', - 'type': - 'checklist', - 'options': [{ - 'label': 'option1', - 'value': 'option1' - }, { - 'label': 'option2', - 'value': 'option2' - }, { - 'label': 'optionN', - 'value': 'optionn' - }] + "required": False, + "instructions": "checklist", + "name": "checklist", + "type": 
"checklist", + "options": [ + {"label": "option1", "value": "option1"}, + {"label": "option2", "value": "option2"}, + {"label": "optionN", "value": "optionn"}, + ], } checklist_index = { - 'required': - False, - 'instructions': - 'checklist_index', - 'name': - 'checklist_index', - 'type': - 'checklist', - 'scope': - 'index', - 'options': [{ - 'label': 'option1_index', - 'value': 'option1_index' - }, { - 'label': 'option2_index', - 'value': 'option2_index' - }, { - 'label': 'optionN_index', - 'value': 'optionn_index' - }] + "required": False, + "instructions": "checklist_index", + "name": "checklist_index", + "type": "checklist", + "scope": "index", + "options": [ + {"label": "option1_index", "value": "option1_index"}, + {"label": "option2_index", "value": "option2_index"}, + {"label": "optionN_index", "value": "optionn_index"}, + ], } free_form_text = { - 'required': False, - 'instructions': 'text', - 'name': 'text', - 'type': 'text', - 'options': [] + "required": False, + "instructions": "text", + "name": "text", + "type": "text", + "options": [], } free_form_text_index = { - 'required': False, - 'instructions': 'text_index', - 'name': 'text_index', - 'type': 'text', - 'scope': 'index', - 'options': [] + "required": False, + "instructions": "text_index", + "name": "text_index", + "type": "text", + "scope": "index", + "options": [], } radio = { - 'required': - False, - 'instructions': - 'radio', - 'name': - 'radio', - 'type': - 'radio', - 'options': [{ - 'label': 'first_radio_answer', - 'value': 'first_radio_answer', - 'options': [] - }, { - 'label': 'second_radio_answer', - 'value': 'second_radio_answer', - 'options': [] - }] + "required": False, + "instructions": "radio", + "name": "radio", + "type": "radio", + "options": [ + { + "label": "first_radio_answer", + "value": "first_radio_answer", + "options": [], + }, + { + "label": "second_radio_answer", + "value": "second_radio_answer", + "options": [], + }, + ], } named_entity = { - 'tool': 'named-entity', - 'name': 'named-entity', - 'required': False, - 'color': '#A30059', - 'classifications': [], + "tool": "named-entity", + "name": "named-entity", + "required": False, + "color": "#A30059", + "classifications": [], } tools = [ @@ -243,53 +214,53 @@ def ontology(): named_entity, ] classifications = [ - checklist, checklist_index, free_form_text, free_form_text_index, radio + checklist, + checklist_index, + free_form_text, + free_form_text_index, + radio, ] return {"tools": tools, "classifications": classifications} @pytest.fixture def polygon_inference(prediction_id_mapping): - polygon = prediction_id_mapping['polygon'].copy() - polygon.update({ - "polygon": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 142.769, - "y": 104.923 - }, { - "x": 57.846, - "y": 118.769 - }, { - "x": 28.308, - "y": 169.846 - }] - }) - del polygon['tool'] + polygon = prediction_id_mapping["polygon"].copy() + polygon.update( + { + "polygon": [ + {"x": 147.692, "y": 118.154}, + {"x": 142.769, "y": 104.923}, + {"x": 57.846, "y": 118.769}, + {"x": 28.308, "y": 169.846}, + ] + } + ) + del polygon["tool"] return polygon @pytest.fixture -def configured_project_with_ontology(client, initial_dataset, ontology, - rand_gen, image_url): +def configured_project_with_ontology( + client, initial_dataset, ontology, rand_gen, image_url +): dataset = initial_dataset project = client.create_project( name=rand_gen(str), queue_mode=QueueMode.Batch, ) editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] + 
client.get_labeling_frontends(where=LabelingFrontend.name == "editor") + )[0] project.setup(editor, ontology) data_row_ids = [] - for _ in range(len(ontology['tools']) + len(ontology['classifications'])): + for _ in range(len(ontology["tools"]) + len(ontology["classifications"])): data_row_ids.append(dataset.create_data_row(row_data=image_url).uid) project.create_batch( rand_gen(str), data_row_ids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) project.data_row_ids = data_row_ids yield project @@ -298,33 +269,44 @@ def configured_project_with_ontology(client, initial_dataset, ontology, @pytest.fixture def configured_project_without_data_rows(client, ontology, rand_gen): - project = client.create_project(name=rand_gen(str), - description=rand_gen(str), - queue_mode=QueueMode.Batch) + project = client.create_project( + name=rand_gen(str), + description=rand_gen(str), + queue_mode=QueueMode.Batch, + ) editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] + client.get_labeling_frontends(where=LabelingFrontend.name == "editor") + )[0] project.setup(editor, ontology) yield project project.delete() @pytest.fixture -def model_run_with_data_rows(client, configured_project_with_ontology, - model_run_predictions, model_run, - wait_for_label_processing): +def model_run_with_data_rows( + client, + configured_project_with_ontology, + model_run_predictions, + model_run, + wait_for_label_processing, +): configured_project_with_ontology.enable_model_assisted_labeling() - use_data_row_ids = [p['dataRow']['id'] for p in model_run_predictions] + use_data_row_ids = [p["dataRow"]["id"] for p in model_run_predictions] model_run.upsert_data_rows(use_data_row_ids) upload_task = LabelImport.create_from_objects( - client, configured_project_with_ontology.uid, - f"label-import-{uuid.uuid4()}", model_run_predictions) + client, + configured_project_with_ontology.uid, + f"label-import-{uuid.uuid4()}", + model_run_predictions, + ) upload_task.wait_until_done() - assert upload_task.state == AnnotationImportState.FINISHED, "Label Import did not finish" - assert len( - upload_task.errors - ) == 0, f"Label Import {upload_task.name} failed with errors {upload_task.errors}" + assert ( + upload_task.state == AnnotationImportState.FINISHED + ), "Label Import did not finish" + assert ( + len(upload_task.errors) == 0 + ), f"Label Import {upload_task.name} failed with errors {upload_task.errors}" labels = wait_for_label_processing(configured_project_with_ontology) label_ids = [label.uid for label in labels] model_run.upsert_labels(label_ids) @@ -334,8 +316,9 @@ def model_run_with_data_rows(client, configured_project_with_ontology, @pytest.fixture -def model_run_predictions(polygon_inference, rectangle_inference, - line_inference): +def model_run_predictions( + polygon_inference, rectangle_inference, line_inference +): # Not supporting mask since there isn't a signed url representing a seg mask to upload return [polygon_inference, rectangle_inference, line_inference] @@ -398,23 +381,26 @@ def prediction_id_mapping(configured_project_with_ontology): ontology = project.ontology().normalized result = {} - for idx, tool in enumerate(ontology['tools'] + ontology['classifications']): - if 'tool' in tool: - tool_type = tool['tool'] + for idx, tool in enumerate(ontology["tools"] + ontology["classifications"]): + if "tool" in tool: + tool_type = tool["tool"] else: - tool_type = tool[ - 'type'] if 'scope' not in tool else 
f"{tool['type']}_{tool['scope']}" # so 'checklist' of 'checklist_index' + tool_type = ( + tool["type"] + if "scope" not in tool + else f"{tool['type']}_{tool['scope']}" + ) # so 'checklist' of 'checklist_index' # TODO: remove this once we have a better way to associate multiple tools instances with a single tool type - if tool_type == 'rectangle': + if tool_type == "rectangle": value = { "uuid": str(uuid.uuid4()), - "schemaId": tool['featureSchemaId'], - "name": tool['name'], + "schemaId": tool["featureSchemaId"], + "name": tool["name"], "dataRow": { "id": project.data_row_ids[idx], }, - 'tool': tool + "tool": tool, } if tool_type not in result: result[tool_type] = [] @@ -422,86 +408,76 @@ def prediction_id_mapping(configured_project_with_ontology): else: result[tool_type] = { "uuid": str(uuid.uuid4()), - "schemaId": tool['featureSchemaId'], - "name": tool['name'], + "schemaId": tool["featureSchemaId"], + "name": tool["name"], "dataRow": { "id": project.data_row_ids[idx], }, - 'tool': tool + "tool": tool, } return result @pytest.fixture def line_inference(prediction_id_mapping): - line = prediction_id_mapping['line'].copy() + line = prediction_id_mapping["line"].copy() line.update( - {"line": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 150.692, - "y": 160.154 - }]}) - del line['tool'] + {"line": [{"x": 147.692, "y": 118.154}, {"x": 150.692, "y": 160.154}]} + ) + del line["tool"] return line @pytest.fixture def polygon_inference(prediction_id_mapping): - polygon = prediction_id_mapping['polygon'].copy() - polygon.update({ - "polygon": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 142.769, - "y": 104.923 - }, { - "x": 57.846, - "y": 118.769 - }, { - "x": 28.308, - "y": 169.846 - }] - }) - del polygon['tool'] + polygon = prediction_id_mapping["polygon"].copy() + polygon.update( + { + "polygon": [ + {"x": 147.692, "y": 118.154}, + {"x": 142.769, "y": 104.923}, + {"x": 57.846, "y": 118.769}, + {"x": 28.308, "y": 169.846}, + ] + } + ) + del polygon["tool"] return polygon def find_tool_by_name(tool_instances, name): for tool in tool_instances: - if tool['name'] == name: + if tool["name"] == name: return tool return None @pytest.fixture def rectangle_inference(prediction_id_mapping): - tool_instance = find_tool_by_name(prediction_id_mapping['rectangle'], - 'bbox') + tool_instance = find_tool_by_name( + prediction_id_mapping["rectangle"], "bbox" + ) rectangle = tool_instance.copy() - rectangle.update({ - "bbox": { - "top": 48, - "left": 58, - "height": 65, - "width": 12 - }, - 'classifications': [{ - "schemaId": - rectangle['tool']['classifications'][0]['featureSchemaId'], - "name": - rectangle['tool']['classifications'][0]['name'], - "answer": { - "schemaId": - rectangle['tool']['classifications'][0]['options'][0] - ['featureSchemaId'], - "name": - rectangle['tool']['classifications'][0]['options'][0] - ['value'] - } - }] - }) - del rectangle['tool'] + rectangle.update( + { + "bbox": {"top": 48, "left": 58, "height": 65, "width": 12}, + "classifications": [ + { + "schemaId": rectangle["tool"]["classifications"][0][ + "featureSchemaId" + ], + "name": rectangle["tool"]["classifications"][0]["name"], + "answer": { + "schemaId": rectangle["tool"]["classifications"][0][ + "options" + ][0]["featureSchemaId"], + "name": rectangle["tool"]["classifications"][0][ + "options" + ][0]["value"], + }, + } + ], + } + ) + del rectangle["tool"] return rectangle diff --git a/libs/labelbox/tests/data/export/legacy/test_export_catalog.py b/libs/labelbox/tests/data/export/legacy/test_export_catalog.py index 
b5aa72a35..635d307f0 100644 --- a/libs/labelbox/tests/data/export/legacy/test_export_catalog.py +++ b/libs/labelbox/tests/data/export/legacy/test_export_catalog.py @@ -1,7 +1,7 @@ import pytest -@pytest.mark.parametrize('data_rows', [3], indirect=True) +@pytest.mark.parametrize("data_rows", [3], indirect=True) def test_catalog_export_v2(client, export_v2_test_helpers, data_rows): datarow_filter_size = 2 data_row_ids = [dr.uid for dr in data_rows] @@ -10,10 +10,12 @@ def test_catalog_export_v2(client, export_v2_test_helpers, data_rows): filters = {"data_row_ids": data_row_ids[:datarow_filter_size]} task_results = export_v2_test_helpers.run_catalog_export_v2_task( - client, filters=filters, params=params) + client, filters=filters, params=params + ) # only 2 datarows should be exported assert len(task_results) == datarow_filter_size # only filtered datarows should be exported - assert set([dr['data_row']['id'] for dr in task_results - ]) == set(data_row_ids[:datarow_filter_size]) + assert set([dr["data_row"]["id"] for dr in task_results]) == set( + data_row_ids[:datarow_filter_size] + ) diff --git a/libs/labelbox/tests/data/export/legacy/test_export_dataset.py b/libs/labelbox/tests/data/export/legacy/test_export_dataset.py index e4a0b50c2..1d628dc86 100644 --- a/libs/labelbox/tests/data/export/legacy/test_export_dataset.py +++ b/libs/labelbox/tests/data/export/legacy/test_export_dataset.py @@ -1,15 +1,17 @@ import pytest -@pytest.mark.parametrize('data_rows', [3], indirect=True) +@pytest.mark.parametrize("data_rows", [3], indirect=True) def test_dataset_export_v2(export_v2_test_helpers, dataset, data_rows): data_row_ids = [dr.uid for dr in data_rows] params = {"performance_details": False, "label_details": False} task_results = export_v2_test_helpers.run_dataset_export_v2_task( - dataset, params=params) + dataset, params=params + ) assert len(task_results) == len(data_row_ids) - assert set([dr['data_row']['id'] for dr in task_results - ]) == set(data_row_ids) + assert set([dr["data_row"]["id"] for dr in task_results]) == set( + data_row_ids + ) # testing with a datarow ids filter datarow_filter_size = 2 @@ -19,13 +21,15 @@ def test_dataset_export_v2(export_v2_test_helpers, dataset, data_rows): filters = {"data_row_ids": data_row_ids[:datarow_filter_size]} task_results = export_v2_test_helpers.run_dataset_export_v2_task( - dataset, filters=filters, params=params) + dataset, filters=filters, params=params + ) # only 2 datarows should be exported assert len(task_results) == datarow_filter_size # only filtered datarows should be exported - assert set([dr['data_row']['id'] for dr in task_results - ]) == set(data_row_ids[:datarow_filter_size]) + assert set([dr["data_row"]["id"] for dr in task_results]) == set( + data_row_ids[:datarow_filter_size] + ) # testing with a global key and a datarow id filter datarow_filter_size = 2 @@ -35,10 +39,12 @@ def test_dataset_export_v2(export_v2_test_helpers, dataset, data_rows): filters = {"global_keys": global_keys[:datarow_filter_size]} task_results = export_v2_test_helpers.run_dataset_export_v2_task( - dataset, filters=filters, params=params) + dataset, filters=filters, params=params + ) # only 2 datarows should be exported assert len(task_results) == datarow_filter_size # only filtered datarows should be exported - assert set([dr['data_row']['global_key'] for dr in task_results - ]) == set(global_keys[:datarow_filter_size]) + assert set([dr["data_row"]["global_key"] for dr in task_results]) == set( + global_keys[:datarow_filter_size] + ) diff --git 
a/libs/labelbox/tests/data/export/legacy/test_export_model_run.py b/libs/labelbox/tests/data/export/legacy/test_export_model_run.py index 7dfd44f0c..2a06c334d 100644 --- a/libs/labelbox/tests/data/export/legacy/test_export_model_run.py +++ b/libs/labelbox/tests/data/export/legacy/test_export_model_run.py @@ -3,7 +3,7 @@ def _model_run_export_v2_results(model_run, task_name, params, num_retries=5): """Export model run results and retry if no results are returned.""" - while (num_retries > 0): + while num_retries > 0: task = model_run.export_v2(task_name, params=params) assert task.name == task_name task.wait_till_done() @@ -30,15 +30,22 @@ def test_model_run_export_v2(model_run_with_data_rows): for task_result in task_results: # Check export param handling - assert 'media_attributes' in task_result and task_result[ - 'media_attributes'] is not None - exported_model_run = task_result['experiments'][ - model_run.model_id]['runs'][model_run.uid] + assert ( + "media_attributes" in task_result + and task_result["media_attributes"] is not None + ) + exported_model_run = task_result["experiments"][model_run.model_id][ + "runs" + ][model_run.uid] task_label_ids_set = set( - map(lambda label: label['id'], exported_model_run['labels'])) + map(lambda label: label["id"], exported_model_run["labels"]) + ) task_prediction_ids_set = set( - map(lambda prediction: prediction['id'], - exported_model_run['predictions'])) + map( + lambda prediction: prediction["id"], + exported_model_run["predictions"], + ) + ) for label_id in task_label_ids_set: assert label_id in label_ids for prediction_id in task_prediction_ids_set: diff --git a/libs/labelbox/tests/data/export/legacy/test_export_project.py b/libs/labelbox/tests/data/export/legacy/test_export_project.py index f7716d5c5..3cd3b9226 100644 --- a/libs/labelbox/tests/data/export/legacy/test_export_project.py +++ b/libs/labelbox/tests/data/export/legacy/test_export_project.py @@ -10,9 +10,12 @@ IMAGE_URL = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg" -def test_project_export_v2(client, export_v2_test_helpers, - configured_project_with_label, - wait_for_data_row_processing): +def test_project_export_v2( + client, + export_v2_test_helpers, + configured_project_with_label, + wait_for_data_row_processing, +): project, dataset, data_row, label = configured_project_with_label data_row = wait_for_data_row_processing(client, data_row) label_id = label.uid @@ -23,55 +26,63 @@ def test_project_export_v2(client, export_v2_test_helpers, "include_labels": True, "media_type_override": MediaType.Image, "project_details": True, - "data_row_details": True + "data_row_details": True, } task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, params=params) + project, task_name=task_name, params=params + ) for task_result in task_results: - task_media_attributes = task_result['media_attributes'] - task_project = task_result['projects'][project.uid] + task_media_attributes = task_result["media_attributes"] + task_project = task_result["projects"][project.uid] task_project_label_ids_set = set( - map(lambda prediction: prediction['id'], task_project['labels'])) - task_project_details = task_project['project_details'] - task_data_row = task_result['data_row'] - task_data_row_details = task_data_row['details'] + map(lambda prediction: prediction["id"], task_project["labels"]) + ) + task_project_details = task_project["project_details"] + task_data_row = task_result["data_row"] + 
task_data_row_details = task_data_row["details"] assert label_id in task_project_label_ids_set # data row - assert task_data_row['id'] == data_row.uid - assert task_data_row['external_id'] == data_row.external_id - assert task_data_row['row_data'] == data_row.row_data + assert task_data_row["id"] == data_row.uid + assert task_data_row["external_id"] == data_row.external_id + assert task_data_row["row_data"] == data_row.row_data # data row details - assert task_data_row_details['dataset_id'] == dataset.uid - assert task_data_row_details['dataset_name'] == dataset.name + assert task_data_row_details["dataset_id"] == dataset.uid + assert task_data_row_details["dataset_name"] == dataset.name - assert task_data_row_details['last_activity_at'] is not None - assert task_data_row_details['created_by'] is not None + assert task_data_row_details["last_activity_at"] is not None + assert task_data_row_details["created_by"] is not None # media attributes - assert task_media_attributes['mime_type'] == data_row.media_attributes[ - 'mimeType'] + assert ( + task_media_attributes["mime_type"] + == data_row.media_attributes["mimeType"] + ) # project name and details - assert task_project['name'] == project.name + assert task_project["name"] == project.name batch = next(project.batches()) - assert task_project_details['batch_id'] == batch.uid - assert task_project_details['batch_name'] == batch.name - assert task_project_details['priority'] is not None - assert task_project_details[ - 'consensus_expected_label_count'] is not None - assert task_project_details['workflow_history'] is not None + assert task_project_details["batch_id"] == batch.uid + assert task_project_details["batch_name"] == batch.name + assert task_project_details["priority"] is not None + assert ( + task_project_details["consensus_expected_label_count"] is not None + ) + assert task_project_details["workflow_history"] is not None # label details - assert task_project['labels'][0]['id'] == label_id + assert task_project["labels"][0]["id"] == label_id -def test_project_export_v2_date_filters(client, export_v2_test_helpers, - configured_project_with_label, - wait_for_data_row_processing): +def test_project_export_v2_date_filters( + client, + export_v2_test_helpers, + configured_project_with_label, + wait_for_data_row_processing, +): project, _, data_row, label = configured_project_with_label data_row = wait_for_data_row_processing(client, data_row) label_id = label.uid @@ -81,7 +92,7 @@ def test_project_export_v2_date_filters(client, export_v2_test_helpers, filters = { "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "task_queue_status": "InReview" + "task_queue_status": "InReview", } # TODO: Right now we don't have a way to test this @@ -90,24 +101,27 @@ def test_project_export_v2_date_filters(client, export_v2_test_helpers, "performance_details": include_performance_details, "include_labels": True, "project_details": True, - "media_type_override": MediaType.Image + "media_type_override": MediaType.Image, } task_queues = project.task_queues() review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") + tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE" + ) project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid) task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, filters=filters, params=params) + project, task_name=task_name, filters=filters, 
params=params + ) for task_result in task_results: - task_project = task_result['projects'][project.uid] + task_project = task_result["projects"][project.uid] task_project_label_ids_set = set( - map(lambda prediction: prediction['id'], task_project['labels'])) + map(lambda prediction: prediction["id"], task_project["labels"]) + ) assert label_id in task_project_label_ids_set - assert task_project['project_details']['workflow_status'] == 'IN_REVIEW' + assert task_project["project_details"]["workflow_status"] == "IN_REVIEW" # TODO: Add back in when we have a way to test this # if include_performance_details: @@ -124,9 +138,12 @@ def test_project_export_v2_date_filters(client, export_v2_test_helpers, export_v2_test_helpers.run_project_export_v2_task(project, filters=filters) -def test_project_export_v2_with_iso_date_filters(client, export_v2_test_helpers, - configured_project_with_label, - wait_for_data_row_processing): +def test_project_export_v2_with_iso_date_filters( + client, + export_v2_test_helpers, + configured_project_with_label, + wait_for_data_row_processing, +): project, _, data_row, label = configured_project_with_label data_row = wait_for_data_row_processing(client, data_row) label_id = label.uid @@ -135,33 +152,40 @@ def test_project_export_v2_with_iso_date_filters(client, export_v2_test_helpers, filters = { "last_activity_at": [ - "2000-01-01T00:00:00+0230", "2050-01-01T00:00:00+0230" + "2000-01-01T00:00:00+0230", + "2050-01-01T00:00:00+0230", ], "label_created_at": [ - "2000-01-01T00:00:00+0230", "2050-01-01T00:00:00+0230" - ] + "2000-01-01T00:00:00+0230", + "2050-01-01T00:00:00+0230", + ], } task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, filters=filters) - assert label_id == task_results[0]['projects'][ - project.uid]['labels'][0]['id'] + project, task_name=task_name, filters=filters + ) + assert ( + label_id == task_results[0]["projects"][project.uid]["labels"][0]["id"] + ) filters = {"last_activity_at": [None, "2050-01-01T00:00:00+0230"]} task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, filters=filters) - assert label_id == task_results[0]['projects'][ - project.uid]['labels'][0]['id'] + project, task_name=task_name, filters=filters + ) + assert ( + label_id == task_results[0]["projects"][project.uid]["labels"][0]["id"] + ) filters = {"label_created_at": ["2050-01-01T00:00:00+0230", None]} task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, filters=filters) + project, task_name=task_name, filters=filters + ) assert len(task_results) == 0 @pytest.mark.parametrize("data_rows", [3], indirect=True) def test_project_export_v2_datarows_filter( - export_v2_test_helpers, - configured_batch_project_with_multiple_datarows): + export_v2_test_helpers, configured_batch_project_with_multiple_datarows +): project, _, data_rows = configured_batch_project_with_multiple_datarows data_row_ids = [dr.uid for dr in data_rows] @@ -170,39 +194,47 @@ def test_project_export_v2_datarows_filter( filters = { "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "data_row_ids": data_row_ids[:datarow_filter_size] + "data_row_ids": data_row_ids[:datarow_filter_size], } params = {"data_row_details": True, "media_type_override": MediaType.Image} task_results = export_v2_test_helpers.run_project_export_v2_task( - project, filters=filters, params=params) + project, filters=filters, 
params=params + ) # only 2 datarows should be exported assert len(task_results) == datarow_filter_size # only filtered datarows should be exported - assert set([dr['data_row']['id'] for dr in task_results - ]) == set(data_row_ids[:datarow_filter_size]) + assert set([dr["data_row"]["id"] for dr in task_results]) == set( + data_row_ids[:datarow_filter_size] + ) global_keys = [dr.global_key for dr in data_rows] filters = { "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "global_keys": global_keys[:datarow_filter_size] + "global_keys": global_keys[:datarow_filter_size], } params = {"data_row_details": True, "media_type_override": MediaType.Image} task_results = export_v2_test_helpers.run_project_export_v2_task( - project, filters=filters, params=params) + project, filters=filters, params=params + ) # only 2 datarows should be exported assert len(task_results) == datarow_filter_size # only filtered datarows should be exported - assert set([dr['data_row']['global_key'] for dr in task_results - ]) == set(global_keys[:datarow_filter_size]) + assert set([dr["data_row"]["global_key"] for dr in task_results]) == set( + global_keys[:datarow_filter_size] + ) def test_batch_project_export_v2( - configured_batch_project_with_label: Tuple[Project, Dataset, DataRow, - Label], - export_v2_test_helpers, dataset: Dataset, image_url: str): + configured_batch_project_with_label: Tuple[ + Project, Dataset, DataRow, Label + ], + export_v2_test_helpers, + dataset: Dataset, + image_url: str, +): project, dataset, *_ = configured_batch_project_with_label batch = list(project.batches())[0] @@ -214,23 +246,24 @@ def test_batch_project_export_v2( params = { "include_performance_details": True, "include_labels": True, - "media_type_override": MediaType.Image + "media_type_override": MediaType.Image, } task_name = "test_batch_export_v2" - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image" - }, - ] * 2) + task = dataset.create_data_rows( + [ + {"row_data": image_url, "external_id": "my-image"}, + ] + * 2 + ) task.wait_till_done() data_rows = [dr.uid for dr in list(dataset.export_data_rows())] - batch_one = f'batch one {uuid.uuid4()}' + batch_one = f"batch one {uuid.uuid4()}" # This test creates two batches, only one batch should be exporter # Creatin second batch that will not be used in the export due to the filter: batch_id project.create_batch(batch_one, data_rows) task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, filters=filters, params=params) - assert (batch.size == len(task_results)) + project, task_name=task_name, filters=filters, params=params + ) + assert batch.size == len(task_results) diff --git a/libs/labelbox/tests/data/export/legacy/test_export_slice.py b/libs/labelbox/tests/data/export/legacy/test_export_slice.py index 2caa6b227..3d1fb7898 100644 --- a/libs/labelbox/tests/data/export/legacy/test_export_slice.py +++ b/libs/labelbox/tests/data/export/legacy/test_export_slice.py @@ -2,15 +2,15 @@ @pytest.mark.skip( - 'Skipping until we have a way to create slices programatically') + "Skipping until we have a way to create slices programatically" +) def test_export_v2_slice(client): # Since we don't have CRUD for slices, we'll just use the one that's already there SLICE_ID = "clk04g1e4000ryb0rgsvy1dty" slice = client.get_catalog_slice(SLICE_ID) - task = slice.export_v2(params={ - "performance_details": False, - "label_details": True - 
}) + task = slice.export_v2( + params={"performance_details": False, "label_details": True} + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None diff --git a/libs/labelbox/tests/data/export/legacy/test_export_video.py b/libs/labelbox/tests/data/export/legacy/test_export_video.py index 3a0cb4149..75a57eca9 100644 --- a/libs/labelbox/tests/data/export/legacy/test_export_video.py +++ b/libs/labelbox/tests/data/export/legacy/test_export_video.py @@ -25,7 +25,6 @@ def test_export_v2_video( bbox_video_annotation_objects, rand_gen, ): - project = configured_project_without_data_rows project_id = project.uid labels = [] @@ -34,17 +33,20 @@ def test_export_v2_video( project.create_batch( rand_gen(str), data_row_uids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) for data_row_uid in data_row_uids: labels = [ - lb_types.Label(data=VideoData(uid=data_row_uid), - annotations=bbox_video_annotation_objects) + lb_types.Label( + data=VideoData(uid=data_row_uid), + annotations=bbox_video_annotation_objects, + ) ] label_import = lb.LabelImport.create_from_objects( - client, project_id, f'test-import-{project_id}', labels) + client, project_id, f"test-import-{project_id}", labels + ) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED @@ -53,13 +55,14 @@ def test_export_v2_video( num_retries = 5 task = None - while (num_retries > 0): + while num_retries > 0: task = project.export_v2( params={ "performance_details": False, "label_details": True, - "interpolated_frames": True - }) + "interpolated_frames": True, + } + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None @@ -70,129 +73,135 @@ def test_export_v2_video( break export_data = task.result - data_row_export = export_data[0]['data_row'] - assert data_row_export['global_key'] == video_data_row['global_key'] - assert data_row_export['row_data'] == video_data_row['row_data'] - assert export_data[0]['media_attributes']['mime_type'] == 'video/mp4' - assert export_data[0]['media_attributes'][ - 'frame_rate'] == 10 # as per the video_data fixture - assert export_data[0]['media_attributes'][ - 'frame_count'] == 100 # as per the video_data fixture + data_row_export = export_data[0]["data_row"] + assert data_row_export["global_key"] == video_data_row["global_key"] + assert data_row_export["row_data"] == video_data_row["row_data"] + assert export_data[0]["media_attributes"]["mime_type"] == "video/mp4" + assert ( + export_data[0]["media_attributes"]["frame_rate"] == 10 + ) # as per the video_data fixture + assert ( + export_data[0]["media_attributes"]["frame_count"] == 100 + ) # as per the video_data fixture expected_export_label = { - 'label_kind': 'Video', - 'version': '1.0.0', - 'id': 'clgjnpysl000xi3zxtnp29fug', - 'label_details': { - 'created_at': '2023-04-16T17:04:23+00:00', - 'updated_at': '2023-04-16T17:04:23+00:00', - 'created_by': 'vbrodsky@labelbox.com', - 'content_last_updated_at': '2023-04-16T17:04:23+00:00', - 'reviews': [] + "label_kind": "Video", + "version": "1.0.0", + "id": "clgjnpysl000xi3zxtnp29fug", + "label_details": { + "created_at": "2023-04-16T17:04:23+00:00", + "updated_at": "2023-04-16T17:04:23+00:00", + "created_by": "vbrodsky@labelbox.com", + "content_last_updated_at": "2023-04-16T17:04:23+00:00", + "reviews": [], }, - 'annotations': { - 'frames': { - '13': { - 'objects': { - 'clgjnpyse000ui3zx6fr1d880': { - 'feature_id': 'clgjnpyse000ui3zx6fr1d880', - 'name': 
'bbox', - 'annotation_kind': 'VideoBoundingBox', - 'classifications': [{ - 'feature_id': 'clgjnpyse000vi3zxtgtfh01y', - 'name': 'nested', - 'radio_answer': { - 'feature_id': 'clgjnpyse000wi3zxnxgv53ps', - 'name': 'radio_option_1', - 'classifications': [] + "annotations": { + "frames": { + "13": { + "objects": { + "clgjnpyse000ui3zx6fr1d880": { + "feature_id": "clgjnpyse000ui3zx6fr1d880", + "name": "bbox", + "annotation_kind": "VideoBoundingBox", + "classifications": [ + { + "feature_id": "clgjnpyse000vi3zxtgtfh01y", + "name": "nested", + "radio_answer": { + "feature_id": "clgjnpyse000wi3zxnxgv53ps", + "name": "radio_option_1", + "classifications": [], + }, } - }], - 'bounding_box': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - } + ], + "bounding_box": { + "top": 98.0, + "left": 146.0, + "height": 243.0, + "width": 236.0, + }, } }, - 'classifications': [] + "classifications": [], }, - '18': { - 'objects': { - 'clgjnpyse000ui3zx6fr1d880': { - 'feature_id': 'clgjnpyse000ui3zx6fr1d880', - 'name': 'bbox', - 'annotation_kind': 'VideoBoundingBox', - 'classifications': [{ - 'feature_id': 'clgjnpyse000vi3zxtgtfh01y', - 'name': 'nested', - 'radio_answer': { - 'feature_id': 'clgjnpyse000wi3zxnxgv53ps', - 'name': 'radio_option_1', - 'classifications': [] + "18": { + "objects": { + "clgjnpyse000ui3zx6fr1d880": { + "feature_id": "clgjnpyse000ui3zx6fr1d880", + "name": "bbox", + "annotation_kind": "VideoBoundingBox", + "classifications": [ + { + "feature_id": "clgjnpyse000vi3zxtgtfh01y", + "name": "nested", + "radio_answer": { + "feature_id": "clgjnpyse000wi3zxnxgv53ps", + "name": "radio_option_1", + "classifications": [], + }, } - }], - 'bounding_box': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - } + ], + "bounding_box": { + "top": 98.0, + "left": 146.0, + "height": 243.0, + "width": 236.0, + }, } }, - 'classifications': [] + "classifications": [], }, - '19': { - 'objects': { - 'clgjnpyse000ui3zx6fr1d880': { - 'feature_id': 'clgjnpyse000ui3zx6fr1d880', - 'name': 'bbox', - 'annotation_kind': 'VideoBoundingBox', - 'classifications': [], - 'bounding_box': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - } + "19": { + "objects": { + "clgjnpyse000ui3zx6fr1d880": { + "feature_id": "clgjnpyse000ui3zx6fr1d880", + "name": "bbox", + "annotation_kind": "VideoBoundingBox", + "classifications": [], + "bounding_box": { + "top": 98.0, + "left": 146.0, + "height": 243.0, + "width": 236.0, + }, } }, - 'classifications': [] - } - }, - 'segments': { - 'clgjnpyse000ui3zx6fr1d880': [[13, 13], [18, 19]] + "classifications": [], + }, }, - 'key_frame_feature_map': { - 'clgjnpyse000ui3zx6fr1d880': { - '13': True, - '18': False, - '19': True + "segments": {"clgjnpyse000ui3zx6fr1d880": [[13, 13], [18, 19]]}, + "key_frame_feature_map": { + "clgjnpyse000ui3zx6fr1d880": { + "13": True, + "18": False, + "19": True, } }, - 'classifications': [] - } + "classifications": [], + }, } - project_export_labels = export_data[0]['projects'][project_id]['labels'] - assert (len(project_export_labels) == len(labels) - ) #note we create 1 label per data row, 1 data row so 1 label + project_export_labels = export_data[0]["projects"][project_id]["labels"] + assert len(project_export_labels) == len( + labels + ) # note we create 1 label per data row, 1 data row so 1 label export_label = project_export_labels[0] - assert (export_label['label_kind']) == 'Video' + assert (export_label["label_kind"]) == "Video" - assert (export_label['label_details'].keys() - ) == 
expected_export_label['label_details'].keys() + assert (export_label["label_details"].keys()) == expected_export_label[ + "label_details" + ].keys() expected_frames_ids = [ vannotation.frame for vannotation in bbox_video_annotation_objects ] - export_annotations = export_label['annotations'] - export_frames = export_annotations['frames'] + export_annotations = export_label["annotations"] + export_frames = export_annotations["frames"] export_frames_ids = [int(frame_id) for frame_id in export_frames.keys()] all_frames_exported = [] for value in expected_frames_ids: # note need to understand why we are exporting more frames than we created if value not in export_frames_ids: all_frames_exported.append(value) - assert (len(all_frames_exported) == 0) + assert len(all_frames_exported) == 0 # BEGINNING OF THE VIDEO INTERPOLATION ASSERTIONS first_frame_id = bbox_video_annotation_objects[0].frame @@ -203,42 +212,50 @@ def test_export_v2_video( assert export_frames_ids == expected_frame_ids - exported_objects_dict = export_frames[str(first_frame_id)]['objects'] + exported_objects_dict = export_frames[str(first_frame_id)]["objects"] # Get the label ID first_exported_label_id = list(exported_objects_dict.keys())[0] # Since the bounding box moves to the right, the interpolated frame content should start a little bit more far to the right - assert export_frames[str(first_frame_id + 1)]['objects'][ - first_exported_label_id]['bounding_box']['left'] > export_frames[ - str(first_frame_id - )]['objects'][first_exported_label_id]['bounding_box']['left'] + assert ( + export_frames[str(first_frame_id + 1)]["objects"][ + first_exported_label_id + ]["bounding_box"]["left"] + > export_frames[str(first_frame_id)]["objects"][ + first_exported_label_id + ]["bounding_box"]["left"] + ) # But it shouldn't be further than the last frame - assert export_frames[str(first_frame_id + 1)]['objects'][ - first_exported_label_id]['bounding_box']['left'] < export_frames[ - str(last_frame_id - )]['objects'][first_exported_label_id]['bounding_box']['left'] + assert ( + export_frames[str(first_frame_id + 1)]["objects"][ + first_exported_label_id + ]["bounding_box"]["left"] + < export_frames[str(last_frame_id)]["objects"][first_exported_label_id][ + "bounding_box" + ]["left"] + ) # END OF THE VIDEO INTERPOLATION ASSERTIONS - frame_with_nested_classifications = export_frames['13'] + frame_with_nested_classifications = export_frames["13"] annotation = None - for _, a in frame_with_nested_classifications['objects'].items(): - if a['name'] == 'bbox': + for _, a in frame_with_nested_classifications["objects"].items(): + if a["name"] == "bbox": annotation = a break - assert (annotation is not None) - assert (annotation['annotation_kind'] == 'VideoBoundingBox') - assert (annotation['classifications']) - assert (annotation['bounding_box'] == { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - }) - classifications = annotation['classifications'] - classification = classifications[0]['radio_answer'] - assert (classification['name'] == 'radio_option_1') - subclassifications = classification['classifications'] + assert annotation is not None + assert annotation["annotation_kind"] == "VideoBoundingBox" + assert annotation["classifications"] + assert annotation["bounding_box"] == { + "top": 98.0, + "left": 146.0, + "height": 243.0, + "width": 236.0, + } + classifications = annotation["classifications"] + classification = classifications[0]["radio_answer"] + assert classification["name"] == "radio_option_1" + subclassifications 
= classification["classifications"] # NOTE predictions services does not support nested classifications at the moment, see # https://labelbox.atlassian.net/browse/AL-5588 - assert (len(subclassifications) == 0) + assert len(subclassifications) == 0 diff --git a/libs/labelbox/tests/data/export/legacy/test_legacy_export.py b/libs/labelbox/tests/data/export/legacy/test_legacy_export.py index 31ae8ca91..93b803f7f 100644 --- a/libs/labelbox/tests/data/export/legacy/test_legacy_export.py +++ b/libs/labelbox/tests/data/export/legacy/test_legacy_export.py @@ -13,8 +13,10 @@ @pytest.mark.skip(reason="broken export v1 api, to be retired soon") def test_export_annotations_nested_checklist( - client, configured_project_with_complex_ontology, - wait_for_data_row_processing): + client, + configured_project_with_complex_ontology, + wait_for_data_row_processing, +): project, data_row = configured_project_with_complex_ontology data_row = wait_for_data_row_processing(client, data_row) ontology = project.ontology().normalized @@ -22,43 +24,44 @@ def test_export_annotations_nested_checklist( tool = ontology["tools"][0] nested_check = [ - subc for subc in tool["classifications"] + subc + for subc in tool["classifications"] if subc["name"] == "test-checklist-class" ][0] - data = [{ - "uuid": - str(uuid.uuid4()), - "schemaId": - tool['featureSchemaId'], - "dataRow": { - "id": data_row.uid - }, - "bbox": { - "top": 20, - "left": 20, - "height": 50, - "width": 50 - }, - "classifications": [{ - "schemaId": - nested_check["featureSchemaId"], - "answers": [ - { - "schemaId": nested_check["options"][0]["featureSchemaId"] - }, + data = [ + { + "uuid": str(uuid.uuid4()), + "schemaId": tool["featureSchemaId"], + "dataRow": {"id": data_row.uid}, + "bbox": {"top": 20, "left": 20, "height": 50, "width": 50}, + "classifications": [ { - "schemaId": nested_check["options"][1]["featureSchemaId"] - }, - ] - }] - }] - task = LabelImport.create_from_objects(client, project.uid, - f'label-import-{uuid.uuid4()}', data) + "schemaId": nested_check["featureSchemaId"], + "answers": [ + { + "schemaId": nested_check["options"][0][ + "featureSchemaId" + ] + }, + { + "schemaId": nested_check["options"][1][ + "featureSchemaId" + ] + }, + ], + } + ], + } + ] + task = LabelImport.create_from_objects( + client, project.uid, f"label-import-{uuid.uuid4()}", data + ) task.wait_until_done() labels = project.label_generator() object_annotation = [ - annot for annot in next(labels).annotations + annot + for annot in next(labels).annotations if isinstance(annot, ObjectAnnotation) ][0] @@ -67,29 +70,26 @@ def test_export_annotations_nested_checklist( @pytest.mark.skip(reason="broken export v1 api, to be retired soon") -def test_export_filtered_dates(client, - configured_project_with_complex_ontology): +def test_export_filtered_dates( + client, configured_project_with_complex_ontology +): project, data_row = configured_project_with_complex_ontology ontology = project.ontology().normalized tool = ontology["tools"][0] - data = [{ - "uuid": str(uuid.uuid4()), - "schemaId": tool['featureSchemaId'], - "dataRow": { - "id": data_row.uid - }, - "bbox": { - "top": 20, - "left": 20, - "height": 50, - "width": 50 + data = [ + { + "uuid": str(uuid.uuid4()), + "schemaId": tool["featureSchemaId"], + "dataRow": {"id": data_row.uid}, + "bbox": {"top": 20, "left": 20, "height": 50, "width": 50}, } - }] + ] - task = LabelImport.create_from_objects(client, project.uid, - f'label-import-{uuid.uuid4()}', data) + task = LabelImport.create_from_objects( + client, project.uid, 
f"label-import-{uuid.uuid4()}", data + ) task.wait_until_done() regular_export = project.export_labels(download=True) @@ -99,39 +99,37 @@ def test_export_filtered_dates(client, assert len(filtered_export) == 1 filtered_export_with_time = project.export_labels( - download=True, start="2020-01-01 00:00:01") + download=True, start="2020-01-01 00:00:01" + ) assert len(filtered_export_with_time) == 1 - empty_export = project.export_labels(download=True, - start="2020-01-01", - end="2020-01-02") + empty_export = project.export_labels( + download=True, start="2020-01-01", end="2020-01-02" + ) assert len(empty_export) == 0 @pytest.mark.skip(reason="broken export v1 api, to be retired soon") -def test_export_filtered_activity(client, - configured_project_with_complex_ontology): +def test_export_filtered_activity( + client, configured_project_with_complex_ontology +): project, data_row = configured_project_with_complex_ontology ontology = project.ontology().normalized tool = ontology["tools"][0] - data = [{ - "uuid": str(uuid.uuid4()), - "schemaId": tool['featureSchemaId'], - "dataRow": { - "id": data_row.uid - }, - "bbox": { - "top": 20, - "left": 20, - "height": 50, - "width": 50 + data = [ + { + "uuid": str(uuid.uuid4()), + "schemaId": tool["featureSchemaId"], + "dataRow": {"id": data_row.uid}, + "bbox": {"top": 20, "left": 20, "height": 50, "width": 50}, } - }] + ] - task = LabelImport.create_from_objects(client, project.uid, - f'label-import-{uuid.uuid4()}', data) + task = LabelImport.create_from_objects( + client, project.uid, f"label-import-{uuid.uuid4()}", data + ) task.wait_until_done() regular_export = project.export_labels(download=True) @@ -140,35 +138,41 @@ def test_export_filtered_activity(client, filtered_export = project.export_labels( download=True, last_activity_start="2020-01-01", - last_activity_end=(datetime.datetime.now() + - datetime.timedelta(days=2)).strftime("%Y-%m-%d")) + last_activity_end=( + datetime.datetime.now() + datetime.timedelta(days=2) + ).strftime("%Y-%m-%d"), + ) assert len(filtered_export) == 1 filtered_export_with_time = project.export_labels( - download=True, last_activity_start="2020-01-01 00:00:01") + download=True, last_activity_start="2020-01-01 00:00:01" + ) assert len(filtered_export_with_time) == 1 empty_export = project.export_labels( download=True, - last_activity_start=(datetime.datetime.now() + - datetime.timedelta(days=2)).strftime("%Y-%m-%d"), + last_activity_start=( + datetime.datetime.now() + datetime.timedelta(days=2) + ).strftime("%Y-%m-%d"), ) empty_export = project.export_labels( download=True, - last_activity_end=(datetime.datetime.now() - - datetime.timedelta(days=1)).strftime("%Y-%m-%d")) + last_activity_end=( + datetime.datetime.now() - datetime.timedelta(days=1) + ).strftime("%Y-%m-%d"), + ) assert len(empty_export) == 0 def test_export_data_rows(project: Project, dataset: Dataset): n_data_rows = 2 - task = dataset.create_data_rows([ - { - "row_data": IMAGE_URL, - "external_id": "my-image" - }, - ] * n_data_rows) + task = dataset.create_data_rows( + [ + {"row_data": IMAGE_URL, "external_id": "my-image"}, + ] + * n_data_rows + ) task.wait_till_done() data_rows = [dr.uid for dr in list(dataset.export_data_rows())] @@ -196,9 +200,9 @@ def test_label_export(configured_project_with_label): exported_labels_url = project.export_labels() assert exported_labels_url is not None exported_labels = requests.get(exported_labels_url) - labels = [example['ID'] for example in exported_labels.json()] + labels = [example["ID"] for example in 
exported_labels.json()] assert labels[0] == label_id - #TODO: Add test for bulk export back. + # TODO: Add test for bulk export back. # The new exporter doesn't work with the create_label mutation @@ -233,11 +237,12 @@ def test_dataset_export(dataset, image_url): @pytest.mark.skip(reason="broken export v1 api, to be retired soon") def test_data_row_export_with_empty_media_attributes( - client, configured_project_with_label, wait_for_data_row_processing): + client, configured_project_with_label, wait_for_data_row_processing +): project, _, data_row, _ = configured_project_with_label data_row = wait_for_data_row_processing(client, data_row) labels = list(project.label_generator()) - assert len( - labels - ) == 1, "Label export job unexpectedly returned an empty result set`" + assert ( + len(labels) == 1 + ), "Label export job unexpectedly returned an empty result set`" assert labels[0].data.media_attributes == {} diff --git a/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py index 0d98d8a89..3e4efbc46 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py @@ -7,9 +7,9 @@ class TestExportDataRow: - - def test_with_data_row_object(self, client, data_row, - wait_for_data_row_processing): + def test_with_data_row_object( + self, client, data_row, wait_for_data_row_processing + ): data_row = wait_for_data_row_processing(client, data_row) time.sleep(7) # temp fix for ES indexing delay export_task = DataRow.export( @@ -22,14 +22,20 @@ def test_with_data_row_object(self, client, data_row, assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1 - assert (json.loads(list(export_task.get_stream())[0].json_str) - ["data_row"]["id"] == data_row.uid) - - def test_with_data_row_object_buffered(self, client, data_row, - wait_for_data_row_processing): + assert ( + json.loads(list(export_task.get_stream())[0].json_str)["data_row"][ + "id" + ] + == data_row.uid + ) + + def test_with_data_row_object_buffered( + self, client, data_row, wait_for_data_row_processing + ): data_row = wait_for_data_row_processing(client, data_row) time.sleep(7) # temp fix for ES indexing delay export_task = DataRow.export( @@ -42,30 +48,42 @@ def test_with_data_row_object_buffered(self, client, data_row, assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1 - assert list(export_task.get_buffered_stream())[0].json["data_row"]["id"] == data_row.uid + assert ( + list(export_task.get_buffered_stream())[0].json["data_row"]["id"] + == data_row.uid + ) def test_with_id(self, client, data_row, wait_for_data_row_processing): data_row = wait_for_data_row_processing(client, data_row) time.sleep(7) # temp fix for ES indexing delay - export_task = DataRow.export(client=client, - data_rows=[data_row.uid], - 
task_name="TestExportDataRow:test_with_id") + export_task = DataRow.export( + client=client, + data_rows=[data_row.uid], + task_name="TestExportDataRow:test_with_id", + ) export_task.wait_till_done() assert export_task.status == "COMPLETE" assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1 - assert (json.loads(list(export_task.get_stream())[0].json_str) - ["data_row"]["id"] == data_row.uid) + assert ( + json.loads(list(export_task.get_stream())[0].json_str)["data_row"][ + "id" + ] + == data_row.uid + ) - def test_with_global_key(self, client, data_row, - wait_for_data_row_processing): + def test_with_global_key( + self, client, data_row, wait_for_data_row_processing + ): data_row = wait_for_data_row_processing(client, data_row) time.sleep(7) # temp fix for ES indexing delay export_task = DataRow.export( @@ -78,11 +96,16 @@ def test_with_global_key(self, client, data_row, assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1 - assert (json.loads(list(export_task.get_stream())[0].json_str) - ["data_row"]["id"] == data_row.uid) + assert ( + json.loads(list(export_task.get_stream())[0].json_str)["data_row"][ + "id" + ] + == data_row.uid + ) def test_with_invalid_id(self, client): export_task = DataRow.export( @@ -95,7 +118,10 @@ def test_with_invalid_id(self, client): assert isinstance(export_task, ExportTask) assert export_task.has_result() is False assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) is None - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) is None + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) + is None + ) + assert ( + export_task.get_total_lines(stream_type=StreamType.RESULT) is None + ) diff --git a/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py index e31f17c44..57f617a00 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py @@ -6,7 +6,6 @@ class TestExportDataset: - @pytest.mark.parametrize("data_rows", [3], indirect=True) def test_export(self, dataset, data_rows): expected_data_row_ids = [dr.uid for dr in data_rows] @@ -18,61 +17,82 @@ def test_export(self, dataset, data_rows): assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == len(expected_data_row_ids) + stream_type=StreamType.RESULT + ) == len(expected_data_row_ids) data_row_ids = list( - map(lambda x: json.loads(x.json_str)["data_row"]["id"], - export_task.get_stream())) + map( + lambda x: 
json.loads(x.json_str)["data_row"]["id"], + export_task.get_stream(), + ) + ) assert data_row_ids.sort() == expected_data_row_ids.sort() @pytest.mark.parametrize("data_rows", [3], indirect=True) def test_with_data_row_filter(self, dataset, data_rows): datarow_filter_size = 3 - expected_data_row_ids = [dr.uid for dr in data_rows - ][:datarow_filter_size] + expected_data_row_ids = [dr.uid for dr in data_rows][ + :datarow_filter_size + ] filters = {"data_row_ids": expected_data_row_ids} export_task = dataset.export( filters=filters, - task_name="TestExportDataset:test_with_data_row_filter") + task_name="TestExportDataset:test_with_data_row_filter", + ) export_task.wait_till_done() assert export_task.status == "COMPLETE" assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == datarow_filter_size + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) + assert ( + export_task.get_total_lines(stream_type=StreamType.RESULT) + == datarow_filter_size + ) data_row_ids = list( - map(lambda x: json.loads(x.json_str)["data_row"]["id"], - export_task.get_stream())) + map( + lambda x: json.loads(x.json_str)["data_row"]["id"], + export_task.get_stream(), + ) + ) assert data_row_ids.sort() == expected_data_row_ids.sort() @pytest.mark.parametrize("data_rows", [3], indirect=True) def test_with_global_key_filter(self, dataset, data_rows): datarow_filter_size = 2 - expected_global_keys = [dr.global_key for dr in data_rows - ][:datarow_filter_size] + expected_global_keys = [dr.global_key for dr in data_rows][ + :datarow_filter_size + ] filters = {"global_keys": expected_global_keys} export_task = dataset.export( filters=filters, - task_name="TestExportDataset:test_with_global_key_filter") + task_name="TestExportDataset:test_with_global_key_filter", + ) export_task.wait_till_done() assert export_task.status == "COMPLETE" assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == datarow_filter_size + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) + assert ( + export_task.get_total_lines(stream_type=StreamType.RESULT) + == datarow_filter_size + ) global_keys = list( - map(lambda x: json.loads(x.json_str)["data_row"]["global_key"], - export_task.get_stream())) + map( + lambda x: json.loads(x.json_str)["data_row"]["global_key"], + export_task.get_stream(), + ) + ) assert global_keys.sort() == expected_global_keys.sort() diff --git a/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py index b0c683486..071acbb5b 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py @@ -5,13 +5,15 @@ class TestExportEmbeddings: - - def test_export_embeddings_precomputed(self, client, dataset, environ, - image_url): - data_row_specs = [{ - "row_data": image_url, - "external_id": "image", - }] + def test_export_embeddings_precomputed( + self, client, dataset, environ, image_url + ): + data_row_specs = [ + { + "row_data": 
image_url, + "external_id": "image", + } + ] task = dataset.create_data_rows(data_row_specs) task.wait_till_done() export_task = dataset.export(params={"embeddings": True}) @@ -21,30 +23,42 @@ def test_export_embeddings_precomputed(self, client, dataset, environ, assert export_task.has_errors() is False results = [] - export_task.get_stream(converter=JsonConverter(), - stream_type=StreamType.RESULT).start( - stream_handler=lambda output: results.append( - json.loads(output.json_str))) + export_task.get_stream( + converter=JsonConverter(), stream_type=StreamType.RESULT + ).start( + stream_handler=lambda output: results.append( + json.loads(output.json_str) + ) + ) assert len(results) == len(data_row_specs) result = results[0] assert "embeddings" in result assert len(result["embeddings"]) > 0 - assert result["embeddings"][0][ - "name"] == "Image Embedding V2 (CLIP ViT-B/32)" + assert ( + result["embeddings"][0]["name"] + == "Image Embedding V2 (CLIP ViT-B/32)" + ) assert len(result["embeddings"][0]["values"]) == 1 - def test_export_embeddings_custom(self, client, dataset, image_url, - embedding): + def test_export_embeddings_custom( + self, client, dataset, image_url, embedding + ): vector = [random.uniform(1.0, 2.0) for _ in range(embedding.dims)] - import_task = dataset.create_data_rows([{ - "row_data": image_url, - "embeddings": [{ - "embedding_id": embedding.id, - "vector": vector, - }], - }]) + import_task = dataset.create_data_rows( + [ + { + "row_data": image_url, + "embeddings": [ + { + "embedding_id": embedding.id, + "vector": vector, + } + ], + } + ] + ) import_task.wait_till_done() assert import_task.status == "COMPLETE" @@ -55,15 +69,19 @@ def test_export_embeddings_custom(self, client, dataset, image_url, assert export_task.has_errors() is False results = [] - export_task.get_stream(converter=JsonConverter(), - stream_type=StreamType.RESULT).start( - stream_handler=lambda output: results.append( - json.loads(output.json_str))) + export_task.get_stream( + converter=JsonConverter(), stream_type=StreamType.RESULT + ).start( + stream_handler=lambda output: results.append( + json.loads(output.json_str) + ) + ) assert len(results) == 1 assert "embeddings" in results[0] - assert (len(results[0]["embeddings"]) - >= 1) # should at least contain the custom embedding + assert ( + len(results[0]["embeddings"]) >= 1 + ) # should at least contain the custom embedding for emb in results[0]["embeddings"]: if emb["id"] == embedding.id: assert emb["name"] == embedding.name diff --git a/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py index 0d1244660..ada493fc3 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py @@ -5,7 +5,6 @@ class TestExportModelRun: - def test_export(self, model_run_with_data_rows): model_run, labels = model_run_with_data_rows label_ids = [label.uid for label in labels] @@ -21,22 +20,31 @@ def test_export(self, model_run_with_data_rows): assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == len(expected_data_rows) + stream_type=StreamType.RESULT + ) == 
len(expected_data_rows) for data in export_task.get_stream(): obj = json.loads(data.json_str) - assert "media_attributes" in obj and obj[ - "media_attributes"] is not None + assert ( + "media_attributes" in obj + and obj["media_attributes"] is not None + ) exported_model_run = obj["experiments"][model_run.model_id]["runs"][ - model_run.uid] + model_run.uid + ] task_label_ids_set = set( - map(lambda label: label["id"], exported_model_run["labels"])) + map(lambda label: label["id"], exported_model_run["labels"]) + ) task_prediction_ids_set = set( - map(lambda prediction: prediction["id"], - exported_model_run["predictions"])) + map( + lambda prediction: prediction["id"], + exported_model_run["predictions"], + ) + ) for label_id in task_label_ids_set: assert label_id in label_ids for prediction_id in task_prediction_ids_set: diff --git a/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py index c29239887..818a0178c 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py @@ -10,16 +10,12 @@ from labelbox.schema.data_row import DataRow from labelbox.schema.label import Label -IMAGE_URL = ( - "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg" -) +IMAGE_URL = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg" class TestExportProject: - @pytest.fixture def project_export(self): - def _project_export(project, task_name, filters=None, params=None): export_task = project.export( task_name=task_name, @@ -55,8 +51,9 @@ def test_export( export_task = project_export(project, task_name, params=params) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 for data in export_task.get_stream(): @@ -64,8 +61,8 @@ def test_export( task_media_attributes = obj["media_attributes"] task_project = obj["projects"][project.uid] task_project_label_ids_set = set( - map(lambda prediction: prediction["id"], - task_project["labels"])) + map(lambda prediction: prediction["id"], task_project["labels"]) + ) task_project_details = task_project["project_details"] task_data_row = obj["data_row"] task_data_row_details = task_data_row["details"] @@ -84,8 +81,10 @@ def test_export( assert task_data_row_details["created_by"] is not None # media attributes - assert task_media_attributes[ - "mime_type"] == data_row.media_attributes["mimeType"] + assert ( + task_media_attributes["mime_type"] + == data_row.media_attributes["mimeType"] + ) # project name and details assert task_project["name"] == project.name @@ -93,8 +92,10 @@ def test_export( assert task_project_details["batch_id"] == batch.uid assert task_project_details["batch_name"] == batch.name assert task_project_details["priority"] is not None - assert task_project_details[ - "consensus_expected_label_count"] is not None + assert ( + task_project_details["consensus_expected_label_count"] + is not None + ) assert task_project_details["workflow_history"] is not None # label details @@ -125,27 +126,30 @@ def test_with_date_filters( } task_queues = project.task_queues() review_queue = next( - tq for tq in task_queues if tq.queue_type 
== "MANUAL_REVIEW_QUEUE") + tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE" + ) project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid) - export_task = project_export(project, - task_name, - filters=filters, - params=params) + export_task = project_export( + project, task_name, filters=filters, params=params + ) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 for data in export_task.get_stream(): obj = json.loads(data.json_str) task_project = obj["projects"][project.uid] task_project_label_ids_set = set( - map(lambda prediction: prediction["id"], - task_project["labels"])) + map(lambda prediction: prediction["id"], task_project["labels"]) + ) assert label_id in task_project_label_ids_set - assert task_project["project_details"][ - "workflow_status"] == "IN_REVIEW" + assert ( + task_project["project_details"]["workflow_status"] + == "IN_REVIEW" + ) def test_with_iso_date_filters( self, @@ -160,21 +164,27 @@ def test_with_iso_date_filters( task_name = "TestExportProject:test_with_iso_date_filters" filters = { "last_activity_at": [ - "2000-01-01T00:00:00+0230", "2050-01-01T00:00:00+0230" + "2000-01-01T00:00:00+0230", + "2050-01-01T00:00:00+0230", ], "label_created_at": [ - "2000-01-01T00:00:00+0230", "2050-01-01T00:00:00+0230" + "2000-01-01T00:00:00+0230", + "2050-01-01T00:00:00+0230", ], } export_task = project_export(project, task_name, filters=filters) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 - assert (label_id == json.loads( - list(export_task.get_stream())[0].json_str)["projects"][project.uid] - ["labels"][0]["id"]) + assert ( + label_id + == json.loads(list(export_task.get_stream())[0].json_str)[ + "projects" + ][project.uid]["labels"][0]["id"] + ) def test_with_iso_date_filters_no_start_date( self, @@ -191,12 +201,16 @@ def test_with_iso_date_filters_no_start_date( export_task = project_export(project, task_name, filters=filters) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 - assert (label_id == json.loads( - list(export_task.get_stream())[0].json_str)["projects"][project.uid] - ["labels"][0]["id"]) + assert ( + label_id + == json.loads(list(export_task.get_stream())[0].json_str)[ + "projects" + ][project.uid]["labels"][0]["id"] + ) def test_with_iso_date_filters_and_future_start_date( self, @@ -207,24 +221,30 @@ def test_with_iso_date_filters_and_future_start_date( ): project, _, data_row, _label = configured_project_with_label data_row = wait_for_data_row_processing(client, data_row) - task_name = "TestExportProject:test_with_iso_date_filters_and_future_start_date" + task_name = ( + "TestExportProject:test_with_iso_date_filters_and_future_start_date" + ) filters = {"label_created_at": ["2050-01-01T00:00:00+0230", None]} export_task = project_export(project, task_name, 
filters=filters) assert export_task.has_result() is False assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) is None - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) is None + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) + is None + ) + assert ( + export_task.get_total_lines(stream_type=StreamType.RESULT) is None + ) @pytest.mark.parametrize("data_rows", [3], indirect=True) def test_with_data_row_filter( - self, configured_batch_project_with_multiple_datarows, - project_export): + self, configured_batch_project_with_multiple_datarows, project_export + ): project, _, data_rows = configured_batch_project_with_multiple_datarows datarow_filter_size = 2 - expected_data_row_ids = [dr.uid for dr in data_rows - ][:datarow_filter_size] + expected_data_row_ids = [dr.uid for dr in data_rows][ + :datarow_filter_size + ] task_name = "TestExportProject:test_with_data_row_filter" filters = { "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], @@ -233,32 +253,38 @@ def test_with_data_row_filter( } params = { "data_row_details": True, - "media_type_override": MediaType.Image + "media_type_override": MediaType.Image, } - export_task = project_export(project, - task_name, - filters=filters, - params=params) + export_task = project_export( + project, task_name, filters=filters, params=params + ) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) # only 2 datarows should be exported - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == datarow_filter_size + assert ( + export_task.get_total_lines(stream_type=StreamType.RESULT) + == datarow_filter_size + ) data_row_ids = list( - map(lambda x: json.loads(x.json_str)["data_row"]["id"], - export_task.get_stream())) + map( + lambda x: json.loads(x.json_str)["data_row"]["id"], + export_task.get_stream(), + ) + ) assert data_row_ids.sort() == expected_data_row_ids.sort() @pytest.mark.parametrize("data_rows", [3], indirect=True) def test_with_global_key_filter( - self, configured_batch_project_with_multiple_datarows, - project_export): + self, configured_batch_project_with_multiple_datarows, project_export + ): project, _, data_rows = configured_batch_project_with_multiple_datarows datarow_filter_size = 2 - expected_global_keys = [dr.global_key for dr in data_rows - ][:datarow_filter_size] + expected_global_keys = [dr.global_key for dr in data_rows][ + :datarow_filter_size + ] task_name = "TestExportProject:test_with_global_key_filter" filters = { "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], @@ -267,28 +293,34 @@ def test_with_global_key_filter( } params = { "data_row_details": True, - "media_type_override": MediaType.Image + "media_type_override": MediaType.Image, } - export_task = project_export(project, - task_name, - filters=filters, - params=params) + export_task = project_export( + project, task_name, filters=filters, params=params + ) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) # only 2 datarows should be exported - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == datarow_filter_size + assert ( + 
export_task.get_total_lines(stream_type=StreamType.RESULT) + == datarow_filter_size + ) global_keys = list( - map(lambda x: json.loads(x.json_str)["data_row"]["global_key"], - export_task.get_stream())) + map( + lambda x: json.loads(x.json_str)["data_row"]["global_key"], + export_task.get_stream(), + ) + ) assert global_keys.sort() == expected_global_keys.sort() def test_batch( self, - configured_batch_project_with_label: Tuple[Project, Dataset, DataRow, - Label], + configured_batch_project_with_label: Tuple[ + Project, Dataset, DataRow, Label + ], dataset: Dataset, image_url: str, project_export, @@ -306,12 +338,12 @@ def test_batch( "media_type_override": MediaType.Image, } task_name = "TestExportProject:test_batch" - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image" - }, - ] * 2) + task = dataset.create_data_rows( + [ + {"row_data": image_url, "external_id": "my-image"}, + ] + * 2 + ) task.wait_till_done() data_rows = [result["id"] for result in task.result] batch_one = f"batch one {uuid.uuid4()}" @@ -320,13 +352,15 @@ def test_batch( # Creatin second batch that will not be used in the export due to the filter: batch_id project.create_batch(batch_one, data_rows) - export_task = project_export(project, - task_name, - filters=filters, - params=params) + export_task = project_export( + project, task_name, filters=filters, params=params + ) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == batch.size + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) + assert ( + export_task.get_total_lines(stream_type=StreamType.RESULT) + == batch.size + ) diff --git a/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py index de32509bd..115194a58 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py @@ -10,7 +10,6 @@ class TestExportVideo: - @pytest.fixture def user_id(self, client): return client.get_user().uid @@ -41,12 +40,15 @@ def test_export( for data_row_uid in data_row_uids: labels = [ - lb_types.Label(data=VideoData(uid=data_row_uid), - annotations=bbox_video_annotation_objects) + lb_types.Label( + data=VideoData(uid=data_row_uid), + annotations=bbox_video_annotation_objects, + ) ] label_import = lb.LabelImport.create_from_objects( - client, project_id, f"test-import-{project_id}", labels) + client, project_id, f"test-import-{project_id}", labels + ) label_import.wait_until_done() assert label_import.state == AnnotationImportState.FINISHED @@ -65,18 +67,21 @@ def test_export( assert isinstance(export_task, ExportTask) assert export_task.has_result() assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 + assert ( + export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 + ) export_data = json.loads(list(export_task.get_stream())[0].json_str) data_row_export = export_data["data_row"] assert data_row_export["global_key"] == video_data_row["global_key"] assert data_row_export["row_data"] == video_data_row["row_data"] assert export_data["media_attributes"]["mime_type"] == "video/mp4" - assert export_data["media_attributes"][ - "frame_rate"] == 10 # as per the video_data 
fixture - assert (export_data["media_attributes"]["frame_count"] == 100 - ) # as per the video_data fixture + assert ( + export_data["media_attributes"]["frame_rate"] == 10 + ) # as per the video_data fixture + assert ( + export_data["media_attributes"]["frame_count"] == 100 + ) # as per the video_data fixture expected_export_label = { "label_kind": "Video", "version": "1.0.0", @@ -96,17 +101,17 @@ def test_export( "feature_id": "clgjnpyse000ui3zx6fr1d880", "name": "bbox", "annotation_kind": "VideoBoundingBox", - "classifications": [{ - "feature_id": "clgjnpyse000vi3zxtgtfh01y", - "name": "nested", - "radio_answer": { - "feature_id": - "clgjnpyse000wi3zxnxgv53ps", - "name": - "radio_option_1", - "classifications": [], - }, - }], + "classifications": [ + { + "feature_id": "clgjnpyse000vi3zxtgtfh01y", + "name": "nested", + "radio_answer": { + "feature_id": "clgjnpyse000wi3zxnxgv53ps", + "name": "radio_option_1", + "classifications": [], + }, + } + ], "bounding_box": { "top": 98.0, "left": 146.0, @@ -123,17 +128,17 @@ def test_export( "feature_id": "clgjnpyse000ui3zx6fr1d880", "name": "bbox", "annotation_kind": "VideoBoundingBox", - "classifications": [{ - "feature_id": "clgjnpyse000vi3zxtgtfh01y", - "name": "nested", - "radio_answer": { - "feature_id": - "clgjnpyse000wi3zxnxgv53ps", - "name": - "radio_option_1", - "classifications": [], - }, - }], + "classifications": [ + { + "feature_id": "clgjnpyse000vi3zxtgtfh01y", + "name": "nested", + "radio_answer": { + "feature_id": "clgjnpyse000wi3zxnxgv53ps", + "name": "radio_option_1", + "classifications": [], + }, + } + ], "bounding_box": { "top": 98.0, "left": 146.0, @@ -162,14 +167,12 @@ def test_export( "classifications": [], }, }, - "segments": { - "clgjnpyse000ui3zx6fr1d880": [[13, 13], [18, 19]] - }, + "segments": {"clgjnpyse000ui3zx6fr1d880": [[13, 13], [18, 19]]}, "key_frame_feature_map": { "clgjnpyse000ui3zx6fr1d880": { "13": True, "18": False, - "19": True + "19": True, } }, "classifications": [], @@ -183,8 +186,9 @@ def test_export( export_label = project_export_labels[0] assert (export_label["label_kind"]) == "Video" - assert (export_label["label_details"].keys() - ) == expected_export_label["label_details"].keys() + assert (export_label["label_details"].keys()) == expected_export_label[ + "label_details" + ].keys() expected_frames_ids = [ vannotation.frame for vannotation in bbox_video_annotation_objects @@ -193,9 +197,7 @@ def test_export( export_frames = export_annotations["frames"] export_frames_ids = [int(frame_id) for frame_id in export_frames.keys()] all_frames_exported = [] - for (value) in ( - expected_frames_ids - ): # note need to understand why we are exporting more frames than we created + for value in expected_frames_ids: # note need to understand why we are exporting more frames than we created if value not in export_frames_ids: all_frames_exported.append(value) assert len(all_frames_exported) == 0 @@ -216,15 +218,23 @@ def test_export( # Since the bounding box moves to the right, the interpolated frame content should start # a little bit more far to the right - assert (export_frames[str(first_frame_id + 1)]["objects"] - [first_exported_label_id]["bounding_box"]["left"] - > export_frames[str(first_frame_id)]["objects"] - [first_exported_label_id]["bounding_box"]["left"]) + assert ( + export_frames[str(first_frame_id + 1)]["objects"][ + first_exported_label_id + ]["bounding_box"]["left"] + > export_frames[str(first_frame_id)]["objects"][ + first_exported_label_id + ]["bounding_box"]["left"] + ) # But it shouldn't be 
further than the last frame - assert (export_frames[str(first_frame_id + 1)]["objects"] - [first_exported_label_id]["bounding_box"]["left"] - < export_frames[str(last_frame_id)]["objects"] - [first_exported_label_id]["bounding_box"]["left"]) + assert ( + export_frames[str(first_frame_id + 1)]["objects"][ + first_exported_label_id + ]["bounding_box"]["left"] + < export_frames[str(last_frame_id)]["objects"][ + first_exported_label_id + ]["bounding_box"]["left"] + ) # END OF THE VIDEO INTERPOLATION ASSERTIONS frame_with_nested_classifications = export_frames["13"] diff --git a/libs/labelbox/tests/data/metrics/confusion_matrix/conftest.py b/libs/labelbox/tests/data/metrics/confusion_matrix/conftest.py index ce82ff21d..c61c4f1df 100644 --- a/libs/labelbox/tests/data/metrics/confusion_matrix/conftest.py +++ b/libs/labelbox/tests/data/metrics/confusion_matrix/conftest.py @@ -2,30 +2,47 @@ import pytest -from labelbox.data.annotation_types import ClassificationAnnotation, ObjectAnnotation -from labelbox.data.annotation_types import Polygon, Point, Rectangle, Mask, MaskData, Line, Radio, Text, Checklist, ClassificationAnswer +from labelbox.data.annotation_types import ( + ClassificationAnnotation, + ObjectAnnotation, +) +from labelbox.data.annotation_types import ( + Polygon, + Point, + Rectangle, + Mask, + MaskData, + Line, + Radio, + Text, + Checklist, + ClassificationAnswer, +) import numpy as np from labelbox.data.annotation_types.ner import TextEntity class NameSpace(SimpleNamespace): - - def __init__(self, - predictions, - ground_truths, - expected, - expected_without_subclasses=None): + def __init__( + self, + predictions, + ground_truths, + expected, + expected_without_subclasses=None, + ): super(NameSpace, self).__init__( predictions=predictions, ground_truths=ground_truths, expected=expected, - expected_without_subclasses=expected_without_subclasses or expected) + expected_without_subclasses=expected_without_subclasses or expected, + ) def get_radio(name, answer_name): return ClassificationAnnotation( - name=name, value=Radio(answer=ClassificationAnswer(name=answer_name))) + name=name, value=Radio(answer=ClassificationAnswer(name=answer_name)) + ) def get_text(name, text_content): @@ -33,26 +50,33 @@ def get_text(name, text_content): def get_checklist(name, answer_names): - return ClassificationAnnotation(name=name, - value=Radio(answer=[ - ClassificationAnswer(name=answer_name) - for answer_name in answer_names - ])) + return ClassificationAnnotation( + name=name, + value=Radio( + answer=[ + ClassificationAnswer(name=answer_name) + for answer_name in answer_names + ] + ), + ) def get_polygon(name, points, subclasses=None): return ObjectAnnotation( name=name, value=Polygon(points=[Point(x=x, y=y) for x, y in points]), - classifications=[] if subclasses is None else subclasses) + classifications=[] if subclasses is None else subclasses, + ) def get_rectangle(name, start, end, subclasses=None): return ObjectAnnotation( name=name, - value=Rectangle(start=Point(x=start[0], y=start[1]), - end=Point(x=end[0], y=end[1])), - classifications=[] if subclasses is None else subclasses) + value=Rectangle( + start=Point(x=start[0], y=start[1]), end=Point(x=end[0], y=end[1]) + ), + classifications=[] if subclasses is None else subclasses, + ) def get_mask(name, pixels, color=(1, 1, 1), subclasses=None): @@ -62,272 +86,325 @@ def get_mask(name, pixels, color=(1, 1, 1), subclasses=None): return ObjectAnnotation( name=name, value=Mask(mask=MaskData(arr=mask), color=color), - classifications=[] if 
subclasses is None else subclasses) + classifications=[] if subclasses is None else subclasses, + ) def get_line(name, points, subclasses=None): return ObjectAnnotation( name=name, value=Line(points=[Point(x=x, y=y) for x, y in points]), - classifications=[] if subclasses is None else subclasses) + classifications=[] if subclasses is None else subclasses, + ) def get_point(name, x, y, subclasses=None): return ObjectAnnotation( name=name, value=Point(x=x, y=y), - classifications=[] if subclasses is None else subclasses) + classifications=[] if subclasses is None else subclasses, + ) def get_radio(name, answer_name): return ClassificationAnnotation( - name=name, value=Radio(answer=ClassificationAnswer(name=answer_name))) + name=name, value=Radio(answer=ClassificationAnswer(name=answer_name)) + ) def get_checklist(name, answer_names): - return ClassificationAnnotation(name=name, - value=Checklist(answer=[ - ClassificationAnswer(name=answer_name) - for answer_name in answer_names - ])) + return ClassificationAnnotation( + name=name, + value=Checklist( + answer=[ + ClassificationAnswer(name=answer_name) + for answer_name in answer_names + ] + ), + ) def get_ner(name, start, end, subclasses=None): return ObjectAnnotation( name=name, value=TextEntity(start=start, end=end), - classifications=[] if subclasses is None else subclasses) + classifications=[] if subclasses is None else subclasses, + ) def get_object_pairs(tool_fn, **kwargs): return [ - NameSpace(predictions=[tool_fn("cat", **kwargs)], - ground_truths=[tool_fn("cat", **kwargs)], - expected={'cat': [1, 0, 0, 0]}), + NameSpace( + predictions=[tool_fn("cat", **kwargs)], + ground_truths=[tool_fn("cat", **kwargs)], + expected={"cat": [1, 0, 0, 0]}, + ), NameSpace( predictions=[ - tool_fn("cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="yes")]) + tool_fn( + "cat", + **kwargs, + subclasses=[get_radio("is_animal", answer_name="yes")], + ) ], ground_truths=[ - tool_fn("cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="yes")]) + tool_fn( + "cat", + **kwargs, + subclasses=[get_radio("is_animal", answer_name="yes")], + ) ], - expected={'cat': [1, 0, 0, 0]}, - expected_without_subclasses={'cat': [1, 0, 0, 0]}), - NameSpace(predictions=[ - tool_fn("cat", + expected={"cat": [1, 0, 0, 0]}, + expected_without_subclasses={"cat": [1, 0, 0, 0]}, + ), + NameSpace( + predictions=[ + tool_fn( + "cat", **kwargs, - subclasses=[get_radio("is_animal", answer_name="yes")]) - ], - ground_truths=[ - tool_fn( - "cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - expected={'cat': [0, 1, 0, 1]}, - expected_without_subclasses={'cat': [1, 0, 0, 0]}), - NameSpace(predictions=[ - tool_fn("cat", + subclasses=[get_radio("is_animal", answer_name="yes")], + ) + ], + ground_truths=[ + tool_fn( + "cat", **kwargs, - subclasses=[get_radio("is_animal", answer_name="yes")]), - tool_fn("cat", + subclasses=[get_radio("is_animal", answer_name="no")], + ) + ], + expected={"cat": [0, 1, 0, 1]}, + expected_without_subclasses={"cat": [1, 0, 0, 0]}, + ), + NameSpace( + predictions=[ + tool_fn( + "cat", **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - ground_truths=[ - tool_fn( - "cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - expected={'cat': [1, 1, 0, 0]}, - expected_without_subclasses={'cat': [1, 1, 0, 0]}), - NameSpace(predictions=[ - tool_fn("cat", + subclasses=[get_radio("is_animal", answer_name="yes")], + ), + tool_fn( + "cat", **kwargs, - 
subclasses=[get_radio("is_animal", answer_name="yes")]), - tool_fn("dog", + subclasses=[get_radio("is_animal", answer_name="no")], + ), + ], + ground_truths=[ + tool_fn( + "cat", **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - ground_truths=[ - tool_fn( - "cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - expected={ - 'cat': [0, 1, 0, 1], - 'dog': [0, 1, 0, 0] - }, - expected_without_subclasses={ - 'cat': [1, 0, 0, 0], - 'dog': [0, 1, 0, 0] - }), + subclasses=[get_radio("is_animal", answer_name="no")], + ) + ], + expected={"cat": [1, 1, 0, 0]}, + expected_without_subclasses={"cat": [1, 1, 0, 0]}, + ), NameSpace( - predictions=[tool_fn("cat", **kwargs), - tool_fn("cat", **kwargs)], - ground_truths=[tool_fn("cat", **kwargs), - tool_fn("cat", **kwargs)], - expected={'cat': [2, 0, 0, 0]}), + predictions=[ + tool_fn( + "cat", + **kwargs, + subclasses=[get_radio("is_animal", answer_name="yes")], + ), + tool_fn( + "dog", + **kwargs, + subclasses=[get_radio("is_animal", answer_name="no")], + ), + ], + ground_truths=[ + tool_fn( + "cat", + **kwargs, + subclasses=[get_radio("is_animal", answer_name="no")], + ) + ], + expected={"cat": [0, 1, 0, 1], "dog": [0, 1, 0, 0]}, + expected_without_subclasses={ + "cat": [1, 0, 0, 0], + "dog": [0, 1, 0, 0], + }, + ), + NameSpace( + predictions=[tool_fn("cat", **kwargs), tool_fn("cat", **kwargs)], + ground_truths=[tool_fn("cat", **kwargs), tool_fn("cat", **kwargs)], + expected={"cat": [2, 0, 0, 0]}, + ), NameSpace( - predictions=[tool_fn("cat", **kwargs), - tool_fn("cat", **kwargs)], + predictions=[tool_fn("cat", **kwargs), tool_fn("cat", **kwargs)], ground_truths=[tool_fn("cat", **kwargs)], - expected={'cat': [1, 1, 0, 0]}), + expected={"cat": [1, 1, 0, 0]}, + ), + NameSpace( + predictions=[tool_fn("cat", **kwargs)], + ground_truths=[tool_fn("cat", **kwargs), tool_fn("cat", **kwargs)], + expected={"cat": [1, 0, 0, 1]}, + ), + NameSpace( + predictions=[], + ground_truths=[], + expected=[], + expected_without_subclasses=[], + ), + NameSpace( + predictions=[], + ground_truths=[tool_fn("cat", **kwargs)], + expected={"cat": [0, 0, 0, 1]}, + ), + NameSpace( + predictions=[tool_fn("cat", **kwargs)], + ground_truths=[], + expected={"cat": [0, 1, 0, 0]}, + ), NameSpace( predictions=[tool_fn("cat", **kwargs)], - ground_truths=[tool_fn("cat", **kwargs), - tool_fn("cat", **kwargs)], - expected={'cat': [1, 0, 0, 1]}), - NameSpace(predictions=[], - ground_truths=[], - expected=[], - expected_without_subclasses=[]), - NameSpace(predictions=[], - ground_truths=[tool_fn("cat", **kwargs)], - expected={'cat': [0, 0, 0, 1]}), - NameSpace(predictions=[tool_fn("cat", **kwargs)], - ground_truths=[], - expected={'cat': [0, 1, 0, 0]}), - NameSpace(predictions=[tool_fn("cat", **kwargs)], - ground_truths=[tool_fn("dog", **kwargs)], - expected={ - 'cat': [0, 1, 0, 0], - 'dog': [0, 0, 0, 1] - }) + ground_truths=[tool_fn("dog", **kwargs)], + expected={"cat": [0, 1, 0, 0], "dog": [0, 0, 0, 1]}, + ), ] @pytest.fixture def radio_pairs(): return [ - NameSpace(predictions=[get_radio("is_animal", answer_name="yes")], - ground_truths=[get_radio("is_animal", answer_name="yes")], - expected={'yes': [1, 0, 0, 0]}), - NameSpace(predictions=[get_radio("is_animal", answer_name="yes")], - ground_truths=[get_radio("is_animal", answer_name="no")], - expected={ - 'no': [0, 0, 0, 1], - 'yes': [0, 1, 0, 0] - }), - NameSpace(predictions=[get_radio("is_animal", answer_name="yes")], - ground_truths=[], - expected={'yes': [0, 1, 0, 0]}), - 
NameSpace(predictions=[], - ground_truths=[get_radio("is_animal", answer_name="yes")], - expected={'yes': [0, 0, 0, 1]}), - NameSpace(predictions=[ - get_radio("is_animal", answer_name="yes"), - get_radio("is_short", answer_name="no") - ], - ground_truths=[get_radio("is_animal", answer_name="yes")], - expected={ - 'no': [0, 1, 0, 0], - 'yes': [1, 0, 0, 0] - }), - #Not supported yet: + NameSpace( + predictions=[get_radio("is_animal", answer_name="yes")], + ground_truths=[get_radio("is_animal", answer_name="yes")], + expected={"yes": [1, 0, 0, 0]}, + ), + NameSpace( + predictions=[get_radio("is_animal", answer_name="yes")], + ground_truths=[get_radio("is_animal", answer_name="no")], + expected={"no": [0, 0, 0, 1], "yes": [0, 1, 0, 0]}, + ), + NameSpace( + predictions=[get_radio("is_animal", answer_name="yes")], + ground_truths=[], + expected={"yes": [0, 1, 0, 0]}, + ), + NameSpace( + predictions=[], + ground_truths=[get_radio("is_animal", answer_name="yes")], + expected={"yes": [0, 0, 0, 1]}, + ), + NameSpace( + predictions=[ + get_radio("is_animal", answer_name="yes"), + get_radio("is_short", answer_name="no"), + ], + ground_truths=[get_radio("is_animal", answer_name="yes")], + expected={"no": [0, 1, 0, 0], "yes": [1, 0, 0, 0]}, + ), + # Not supported yet: # NameSpace( - #predictions=[], - #ground_truths=[], - #expected = [0,0,1,0] - #) + # predictions=[], + # ground_truths=[], + # expected = [0,0,1,0] + # ) ] @pytest.fixture def checklist_pairs(): return [ - NameSpace(predictions=[ - get_checklist("animal_attributes", answer_names=["striped"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped"]) - ], - expected={'striped': [1, 0, 0, 0]}), - NameSpace(predictions=[ - get_checklist("animal_attributes", answer_names=["striped"]) - ], - ground_truths=[], - expected={'striped': [0, 1, 0, 0]}), - NameSpace(predictions=[], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped"]) - ], - expected={'striped': [0, 0, 0, 1]}), - NameSpace(predictions=[ - get_checklist("animal_attributes", - answer_names=["striped", "short"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped"]) - ], - expected={ - 'short': [0, 1, 0, 0], - 'striped': [1, 0, 0, 0] - }), - NameSpace(predictions=[ - get_checklist("animal_attributes", answer_names=["striped"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped", "short"]) - ], - expected={ - 'short': [0, 0, 0, 1], - 'striped': [1, 0, 0, 0] - }), - NameSpace(predictions=[ - get_checklist("animal_attributes", - answer_names=["striped", "short", "black"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped", "short"]) - ], - expected={ - 'black': [0, 1, 0, 0], - 'short': [1, 0, 0, 0], - 'striped': [1, 0, 0, 0] - }), - NameSpace(predictions=[ - get_checklist("animal_attributes", - answer_names=["striped", "short", "black"]), - get_checklist("animal_name", answer_names=["doggy", "pup"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped", "short"]), - get_checklist("animal_name", answer_names=["pup"]) - ], - expected={ - 'black': [0, 1, 0, 0], - 'doggy': [0, 1, 0, 0], - 'pup': [1, 0, 0, 0], - 'short': [1, 0, 0, 0], - 'striped': [1, 0, 0, 0] - }) - - #Not supported yet: + NameSpace( + predictions=[ + get_checklist("animal_attributes", answer_names=["striped"]) + ], + ground_truths=[ + get_checklist("animal_attributes", answer_names=["striped"]) + ], + expected={"striped": [1, 0, 0, 0]}, 
+ ), + NameSpace( + predictions=[ + get_checklist("animal_attributes", answer_names=["striped"]) + ], + ground_truths=[], + expected={"striped": [0, 1, 0, 0]}, + ), + NameSpace( + predictions=[], + ground_truths=[ + get_checklist("animal_attributes", answer_names=["striped"]) + ], + expected={"striped": [0, 0, 0, 1]}, + ), + NameSpace( + predictions=[ + get_checklist( + "animal_attributes", answer_names=["striped", "short"] + ) + ], + ground_truths=[ + get_checklist("animal_attributes", answer_names=["striped"]) + ], + expected={"short": [0, 1, 0, 0], "striped": [1, 0, 0, 0]}, + ), + NameSpace( + predictions=[ + get_checklist("animal_attributes", answer_names=["striped"]) + ], + ground_truths=[ + get_checklist( + "animal_attributes", answer_names=["striped", "short"] + ) + ], + expected={"short": [0, 0, 0, 1], "striped": [1, 0, 0, 0]}, + ), + NameSpace( + predictions=[ + get_checklist( + "animal_attributes", + answer_names=["striped", "short", "black"], + ) + ], + ground_truths=[ + get_checklist( + "animal_attributes", answer_names=["striped", "short"] + ) + ], + expected={ + "black": [0, 1, 0, 0], + "short": [1, 0, 0, 0], + "striped": [1, 0, 0, 0], + }, + ), + NameSpace( + predictions=[ + get_checklist( + "animal_attributes", + answer_names=["striped", "short", "black"], + ), + get_checklist("animal_name", answer_names=["doggy", "pup"]), + ], + ground_truths=[ + get_checklist( + "animal_attributes", answer_names=["striped", "short"] + ), + get_checklist("animal_name", answer_names=["pup"]), + ], + expected={ + "black": [0, 1, 0, 0], + "doggy": [0, 1, 0, 0], + "pup": [1, 0, 0, 0], + "short": [1, 0, 0, 0], + "striped": [1, 0, 0, 0], + }, + ), + # Not supported yet: # NameSpace( - #predictions=[], - #ground_truths=[], - #expected = [0,0,1,0] - #) + # predictions=[], + # ground_truths=[], + # expected = [0,0,1,0] + # ) ] @pytest.fixture def polygon_pairs(): - return get_object_pairs(get_polygon, - points=[[0, 0], [10, 0], [10, 10], [0, 10]]) + return get_object_pairs( + get_polygon, points=[[0, 0], [10, 0], [10, 10], [0, 10]] + ) @pytest.fixture @@ -342,8 +419,9 @@ def mask_pairs(): @pytest.fixture def line_pairs(): - return get_object_pairs(get_line, - points=[[0, 0], [10, 0], [10, 10], [0, 10]]) + return get_object_pairs( + get_line, points=[[0, 0], [10, 0], [10, 10], [0, 10]] + ) @pytest.fixture @@ -359,47 +437,39 @@ def ner_pairs(): @pytest.fixture() def pair_iou_thresholds(): return [ - NameSpace(predictions=[ - get_polygon("cat", points=[[0, 0], [10, 0], [10, 10], [0, 10]]), - ], - ground_truths=[ - get_polygon("cat", - points=[[0, 0], [5, 0], [5, 5], [0, 5]]), - ], - expected={ - 0.2: [1, 0, 0, 0], - 0.3: [0, 1, 0, 1] - }), + NameSpace( + predictions=[ + get_polygon("cat", points=[[0, 0], [10, 0], [10, 10], [0, 10]]), + ], + ground_truths=[ + get_polygon("cat", points=[[0, 0], [5, 0], [5, 5], [0, 5]]), + ], + expected={0.2: [1, 0, 0, 0], 0.3: [0, 1, 0, 1]}, + ), NameSpace( predictions=[get_rectangle("cat", start=[0, 0], end=[10, 10])], ground_truths=[get_rectangle("cat", start=[0, 0], end=[5, 5])], - expected={ - 0.2: [1, 0, 0, 0], - 0.3: [0, 1, 0, 1] - }), - NameSpace(predictions=[get_point("cat", x=0, y=0)], - ground_truths=[get_point("cat", x=20, y=20)], - expected={ - 0.5: [1, 0, 0, 0], - 0.65: [0, 1, 0, 1] - }), - NameSpace(predictions=[ - get_line("cat", points=[[0, 0], [10, 0], [10, 10], [0, 10]]) - ], - ground_truths=[ - get_line("cat", - points=[[0, 0], [100, 0], [100, 100], [0, 100]]) - ], - expected={ - 0.3: [1, 0, 0, 0], - 0.65: [0, 1, 0, 1] - }), - NameSpace(predictions=[ 
- get_mask("cat", pixels=[[0, 0], [1, 1], [2, 2], [3, 3]]) - ], - ground_truths=[get_mask("cat", pixels=[[0, 0], [1, 1]])], - expected={ - 0.4: [1, 0, 0, 0], - 0.6: [0, 1, 0, 1] - }), + expected={0.2: [1, 0, 0, 0], 0.3: [0, 1, 0, 1]}, + ), + NameSpace( + predictions=[get_point("cat", x=0, y=0)], + ground_truths=[get_point("cat", x=20, y=20)], + expected={0.5: [1, 0, 0, 0], 0.65: [0, 1, 0, 1]}, + ), + NameSpace( + predictions=[ + get_line("cat", points=[[0, 0], [10, 0], [10, 10], [0, 10]]) + ], + ground_truths=[ + get_line("cat", points=[[0, 0], [100, 0], [100, 100], [0, 100]]) + ], + expected={0.3: [1, 0, 0, 0], 0.65: [0, 1, 0, 1]}, + ), + NameSpace( + predictions=[ + get_mask("cat", pixels=[[0, 0], [1, 1], [2, 2], [3, 3]]) + ], + ground_truths=[get_mask("cat", pixels=[[0, 0], [1, 1]])], + expected={0.4: [1, 0, 0, 0], 0.6: [0, 1, 0, 1]}, + ), ] diff --git a/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_data_row.py b/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_data_row.py index e84207ac2..e3ac86213 100644 --- a/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_data_row.py +++ b/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_data_row.py @@ -1,47 +1,57 @@ from pytest_cases import fixture_ref from pytest_cases import parametrize, fixture_ref -from labelbox.data.metrics.confusion_matrix.confusion_matrix import confusion_matrix_metric +from labelbox.data.metrics.confusion_matrix.confusion_matrix import ( + confusion_matrix_metric, +) -@parametrize("tool_examples", [ - fixture_ref('polygon_pairs'), - fixture_ref('rectangle_pairs'), - fixture_ref('mask_pairs'), - fixture_ref('line_pairs'), - fixture_ref('point_pairs'), - fixture_ref('ner_pairs') -]) +@parametrize( + "tool_examples", + [ + fixture_ref("polygon_pairs"), + fixture_ref("rectangle_pairs"), + fixture_ref("mask_pairs"), + fixture_ref("line_pairs"), + fixture_ref("point_pairs"), + fixture_ref("ner_pairs"), + ], +) def test_overlapping_objects(tool_examples): for example in tool_examples: - - for include_subclasses, expected_attr_name in [[ - True, 'expected' - ], [False, 'expected_without_subclasses']]: + for include_subclasses, expected_attr_name in [ + [True, "expected"], + [False, "expected_without_subclasses"], + ]: score = confusion_matrix_metric( example.ground_truths, example.predictions, - include_subclasses=include_subclasses) + include_subclasses=include_subclasses, + ) if len(getattr(example, expected_attr_name)) == 0: assert len(score) == 0 else: expected = [0, 0, 0, 0] - for expected_values in getattr(example, - expected_attr_name).values(): + for expected_values in getattr( + example, expected_attr_name + ).values(): for idx in range(4): expected[idx] += expected_values[idx] assert score[0].value == tuple( - expected), f"{example.predictions},{example.ground_truths}" + expected + ), f"{example.predictions},{example.ground_truths}" -@parametrize("tool_examples", - [fixture_ref('checklist_pairs'), - fixture_ref('radio_pairs')]) +@parametrize( + "tool_examples", + [fixture_ref("checklist_pairs"), fixture_ref("radio_pairs")], +) def test_overlapping_classifications(tool_examples): for example in tool_examples: - score = confusion_matrix_metric(example.ground_truths, - example.predictions) + score = confusion_matrix_metric( + example.ground_truths, example.predictions + ) if len(example.expected) == 0: assert len(score) == 0 else: @@ -50,15 +60,16 @@ def test_overlapping_classifications(tool_examples): for idx in range(4): 
expected[idx] += expected_values[idx] assert score[0].value == tuple( - expected), f"{example.predictions},{example.ground_truths}" + expected + ), f"{example.predictions},{example.ground_truths}" def test_partial_overlap(pair_iou_thresholds): for example in pair_iou_thresholds: for iou in example.expected.keys(): - score = confusion_matrix_metric(example.predictions, - example.ground_truths, - iou=iou) + score = confusion_matrix_metric( + example.predictions, example.ground_truths, iou=iou + ) assert score[0].value == tuple( example.expected[iou] ), f"{example.predictions},{example.ground_truths}" diff --git a/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_feature.py b/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_feature.py index f55555e75..818c01f72 100644 --- a/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_feature.py +++ b/libs/labelbox/tests/data/metrics/confusion_matrix/test_confusion_matrix_feature.py @@ -1,26 +1,33 @@ from pytest_cases import fixture_ref from pytest_cases import parametrize, fixture_ref -from labelbox.data.metrics.confusion_matrix.confusion_matrix import feature_confusion_matrix_metric - - -@parametrize("tool_examples", [ - fixture_ref('polygon_pairs'), - fixture_ref('rectangle_pairs'), - fixture_ref('mask_pairs'), - fixture_ref('line_pairs'), - fixture_ref('point_pairs'), - fixture_ref('ner_pairs') -]) +from labelbox.data.metrics.confusion_matrix.confusion_matrix import ( + feature_confusion_matrix_metric, +) + + +@parametrize( + "tool_examples", + [ + fixture_ref("polygon_pairs"), + fixture_ref("rectangle_pairs"), + fixture_ref("mask_pairs"), + fixture_ref("line_pairs"), + fixture_ref("point_pairs"), + fixture_ref("ner_pairs"), + ], +) def test_overlapping_objects(tool_examples): for example in tool_examples: - for include_subclasses, expected_attr_name in [[ - True, 'expected' - ], [False, 'expected_without_subclasses']]: + for include_subclasses, expected_attr_name in [ + [True, "expected"], + [False, "expected_without_subclasses"], + ]: metrics = feature_confusion_matrix_metric( example.ground_truths, example.predictions, - include_subclasses=include_subclasses) + include_subclasses=include_subclasses, + ) metrics = {r.feature_name: list(r.value) for r in metrics} if len(getattr(example, expected_attr_name)) == 0: @@ -31,17 +38,20 @@ def test_overlapping_objects(tool_examples): ), f"{example.predictions},{example.ground_truths}" -@parametrize("tool_examples", - [fixture_ref('checklist_pairs'), - fixture_ref('radio_pairs')]) +@parametrize( + "tool_examples", + [fixture_ref("checklist_pairs"), fixture_ref("radio_pairs")], +) def test_overlapping_classifications(tool_examples): for example in tool_examples: - - metrics = feature_confusion_matrix_metric(example.ground_truths, - example.predictions) + metrics = feature_confusion_matrix_metric( + example.ground_truths, example.predictions + ) metrics = {r.feature_name: list(r.value) for r in metrics} if len(example.expected) == 0: assert len(metrics) == 0 else: - assert metrics == example.expected, f"{example.predictions},{example.ground_truths}" + assert ( + metrics == example.expected + ), f"{example.predictions},{example.ground_truths}" diff --git a/libs/labelbox/tests/data/metrics/iou/data_row/conftest.py b/libs/labelbox/tests/data/metrics/iou/data_row/conftest.py index d25abe2cf..6614cecf4 100644 --- a/libs/labelbox/tests/data/metrics/iou/data_row/conftest.py +++ b/libs/labelbox/tests/data/metrics/iou/data_row/conftest.py @@ -7,780 
+7,696 @@ class NameSpace(SimpleNamespace): - - def __init__(self, - predictions, - labels, - expected, - expected_without_subclasses=None, - data_row_expected=None, - media_attributes=None, - metadata=None, - classifications=None): + def __init__( + self, + predictions, + labels, + expected, + expected_without_subclasses=None, + data_row_expected=None, + media_attributes=None, + metadata=None, + classifications=None, + ): super(NameSpace, self).__init__( predictions=predictions, labels={ - 'DataRow ID': 'ckppihxc10005aeyjen11h7jh', - 'Labeled Data': "https://.jpg", - 'Media Attributes': media_attributes or {}, - 'DataRow Metadata': metadata or [], - 'Label': { - 'objects': labels, - 'classifications': classifications or [] - } + "DataRow ID": "ckppihxc10005aeyjen11h7jh", + "Labeled Data": "https://.jpg", + "Media Attributes": media_attributes or {}, + "DataRow Metadata": metadata or [], + "Label": { + "objects": labels, + "classifications": classifications or [], + }, }, expected=expected, expected_without_subclasses=expected_without_subclasses or expected, - data_row_expected=data_row_expected) + data_row_expected=data_row_expected, + ) @pytest.fixture def polygon_pair(): - return NameSpace(labels=[{ - 'featureId': - 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 1 - }, { - 'x': 0, - 'y': 1 - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 0.5 - }, { - 'x': 0, - 'y': 0.5 - }] - }], - expected=0.5) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "polygon": [ + {"x": 0, "y": 0}, + {"x": 1, "y": 0}, + {"x": 1, "y": 1}, + {"x": 0, "y": 1}, + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "polygon": [ + {"x": 0, "y": 0}, + {"x": 1, "y": 0}, + {"x": 1, "y": 0.5}, + {"x": 0, "y": 0.5}, + ], + } + ], + expected=0.5, + ) @pytest.fixture def box_pair(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - } - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - } - }], - expected=1.0) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "bbox": { + "top": 1099, + "left": 2010, + "height": 690, + "width": 591, + }, + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "bbox": { + "top": 1099, + "left": 2010, + "height": 690, + "width": 591, + }, + } + ], + expected=1.0, + ) @pytest.fixture def unmatched_prediction(): - return NameSpace(labels=[{ - 'featureId': - 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 1 - }, { - 'x': 0, - 'y': 1 - }] - }], - 
predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 0.5 - }, { - 'x': 0, - 'y': 0.5 - }] - }, { - 'uuid': - 'd0ba2520-02e9-47d4-8736-088bbdbabbc3', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'polygon': [{ - 'x': 10, - 'y': 10 - }, { - 'x': 11, - 'y': 10 - }, { - 'x': 11, - 'y': 1.5 - }, { - 'x': 10, - 'y': 1.5 - }] - }], - expected=0.25) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "polygon": [ + {"x": 0, "y": 0}, + {"x": 1, "y": 0}, + {"x": 1, "y": 1}, + {"x": 0, "y": 1}, + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "polygon": [ + {"x": 0, "y": 0}, + {"x": 1, "y": 0}, + {"x": 1, "y": 0.5}, + {"x": 0, "y": 0.5}, + ], + }, + { + "uuid": "d0ba2520-02e9-47d4-8736-088bbdbabbc3", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "polygon": [ + {"x": 10, "y": 10}, + {"x": 11, "y": 10}, + {"x": 11, "y": 1.5}, + {"x": 10, "y": 1.5}, + ], + }, + ], + expected=0.25, + ) @pytest.fixture def unmatched_label(): - return NameSpace(labels=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 1 - }, { - 'x': 0, - 'y': 1 - }] - }, { - 'featureId': - 'ckppiw3bs0007aeyjs3pvrqzi', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'polygon': [{ - 'x': 10, - 'y': 10 - }, { - 'x': 11, - 'y': 10 - }, { - 'x': 11, - 'y': 11 - }, { - 'x': 10, - 'y': 11 - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 0.5 - }, { - 'x': 0, - 'y': 0.5 - }] - }], - expected=0.25) + return NameSpace( + labels=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "polygon": [ + {"x": 0, "y": 0}, + {"x": 1, "y": 0}, + {"x": 1, "y": 1}, + {"x": 0, "y": 1}, + ], + }, + { + "featureId": "ckppiw3bs0007aeyjs3pvrqzi", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "polygon": [ + {"x": 10, "y": 10}, + {"x": 11, "y": 10}, + {"x": 11, "y": 11}, + {"x": 10, "y": 11}, + ], + }, + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "polygon": [ + {"x": 0, "y": 0}, + {"x": 1, "y": 0}, + {"x": 1, "y": 0.5}, + {"x": 0, "y": 0.5}, + ], + } + ], + expected=0.25, + ) def create_mask_url(indices, h, w, value): mask = np.zeros((h, w, 3), dtype=np.uint8) for idx in indices: mask[idx] = value - return base64.b64encode(mask.tobytes()).decode('utf-8') + return base64.b64encode(mask.tobytes()).decode("utf-8") @pytest.fixture def mask_pair(): - return NameSpace(labels=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'instanceURI': - create_mask_url([(0, 0), (0, 1)], 32, 32, (255, 255, 255)) - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 
'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'mask': { - 'instanceURI': - create_mask_url([(0, 0)], 32, 32, (1, 1, 1)), - 'colorRGB': (1, 1, 1) - } - }], - expected=0.5) + return NameSpace( + labels=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "instanceURI": create_mask_url( + [(0, 0), (0, 1)], 32, 32, (255, 255, 255) + ), + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "mask": { + "instanceURI": create_mask_url([(0, 0)], 32, 32, (1, 1, 1)), + "colorRGB": (1, 1, 1), + }, + } + ], + expected=0.5, + ) @pytest.fixture def matching_radio(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckrm02no8000008l3arwp6h4f', - 'answer': { - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - } - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckrm02no8000008l3arwp6h4f', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answer': { - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - } - }], - expected=1.) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckrm02no8000008l3arwp6h4f", + "answer": {"schemaId": "ckppid25v0000aeyjmxfwlc7t"}, + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckrm02no8000008l3arwp6h4f", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answer": {"schemaId": "ckppid25v0000aeyjmxfwlc7t"}, + } + ], + expected=1.0, + ) @pytest.fixture def empty_radio_label(): - return NameSpace(labels=[], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answer': { - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - } - }], - expected=0) + return NameSpace( + labels=[], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answer": {"schemaId": "ckppid25v0000aeyjmxfwlc7t"}, + } + ], + expected=0, + ) @pytest.fixture def empty_radio_prediction(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': { - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - } - }], - predictions=[], - expected=0) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answer": {"schemaId": "ckppid25v0000aeyjmxfwlc7t"}, + } + ], + predictions=[], + expected=0, + ) @pytest.fixture def matching_checklist(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }] - }], - data_row_expected=1., - 
expected={1.0: 3}) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + {"schemaId": "ckppidq4u0002aeyjmcc4toxw"}, + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + {"schemaId": "ckppidq4u0002aeyjmcc4toxw"}, + ], + } + ], + data_row_expected=1.0, + expected={1.0: 3}, + ) @pytest.fixture def partially_matching_checklist_1(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }, { - 'schemaId': 'ckppie29m0003aeyjk1ixzcom' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }, { - 'schemaId': 'ckppiebx80004aeyjuwvos69e' - }] - }], - data_row_expected=0.6, - expected={ - 0.0: 2, - 1.0: 3 - }) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + {"schemaId": "ckppidq4u0002aeyjmcc4toxw"}, + {"schemaId": "ckppie29m0003aeyjk1ixzcom"}, + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + {"schemaId": "ckppidq4u0002aeyjmcc4toxw"}, + {"schemaId": "ckppiebx80004aeyjuwvos69e"}, + ], + } + ], + data_row_expected=0.6, + expected={0.0: 2, 1.0: 3}, + ) @pytest.fixture def partially_matching_checklist_2(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }, { - 'schemaId': 'ckppiebx80004aeyjuwvos69e' - }] - }], - data_row_expected=0.5, - expected={ - 1.0: 2, - 0.0: 2 - }) + return NameSpace( + labels=[], 
+ classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + {"schemaId": "ckppidq4u0002aeyjmcc4toxw"}, + {"schemaId": "ckppiebx80004aeyjuwvos69e"}, + ], + } + ], + data_row_expected=0.5, + expected={1.0: 2, 0.0: 2}, + ) @pytest.fixture def partially_matching_checklist_3(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }, { - 'schemaId': 'ckppiebx80004aeyjuwvos69e' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }] - }], - data_row_expected=0.5, - expected={ - 1.0: 2, - 0.0: 2 - }) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + {"schemaId": "ckppidq4u0002aeyjmcc4toxw"}, + {"schemaId": "ckppiebx80004aeyjuwvos69e"}, + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answers": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + }, + {"schemaId": "ckppide010001aeyj0yhiaghc"}, + ], + } + ], + data_row_expected=0.5, + expected={1.0: 2, 0.0: 2}, + ) @pytest.fixture def empty_checklist_label(): - return NameSpace(labels=[], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - }] - }], - data_row_expected=0.0, - expected={0.0: 1}) + return NameSpace( + labels=[], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answers": [{"schemaId": "ckppid25v0000aeyjmxfwlc7t"}], + } + ], + data_row_expected=0.0, + expected={0.0: 1}, + ) @pytest.fixture def empty_checklist_prediction(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - }] - }], - predictions=[], - data_row_expected=0.0, - expected={0.0: 1}) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": 
"1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answers": [{"schemaId": "ckppid25v0000aeyjmxfwlc7t"}], + } + ], + predictions=[], + data_row_expected=0.0, + expected={0.0: 1}, + ) @pytest.fixture def matching_text(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'test' - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answer': 'test' - }], - expected=1.0) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answer": "test", + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answer": "test", + } + ], + expected=1.0, + ) @pytest.fixture def not_matching_text(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'test' - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answer': 'not_test' - }], - expected=0.) + return NameSpace( + labels=[], + classifications=[ + { + "featureId": "1234567890111213141516171", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answer": "test", + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "answer": "not_test", + } + ], + expected=0.0, + ) @pytest.fixture def test_box_with_subclass(): - return NameSpace(labels=[{ - 'featureId': - 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - }, - 'classifications': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'test' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - }, - 'classifications': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'test' - }] - }], - expected=1.0) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "bbox": { + "top": 1099, + "left": 2010, + "height": 690, + "width": 591, + }, + "classifications": [ + {"schemaId": "ckppid25v0000aeyjmxfwlc7t", "answer": "test"} + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "bbox": { + "top": 1099, + "left": 2010, + "height": 690, + "width": 591, + }, + "classifications": [ + {"schemaId": "ckppid25v0000aeyjmxfwlc7t", "answer": "test"} + ], + } + ], + expected=1.0, + ) @pytest.fixture def test_box_with_wrong_subclass(): - return NameSpace(labels=[{ - 'featureId': - 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - }, - 'classifications': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 
'test' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - }, - 'classifications': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'not_test' - }] - }], - expected=0.5, - expected_without_subclasses=1.0) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "bbox": { + "top": 1099, + "left": 2010, + "height": 690, + "width": 591, + }, + "classifications": [ + {"schemaId": "ckppid25v0000aeyjmxfwlc7t", "answer": "test"} + ], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "bbox": { + "top": 1099, + "left": 2010, + "height": 690, + "width": 591, + }, + "classifications": [ + { + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "answer": "not_test", + } + ], + } + ], + expected=0.5, + expected_without_subclasses=1.0, + ) @pytest.fixture def line_pair(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "line": [{ - "x": 0, - "y": 100 - }, { - "x": 0, - "y": 0 - }], - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - "line": [{ - "x": 5, - "y": 95 - }, { - "x": 0, - "y": 0 - }], - }], - expected=0.9496975567603978) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "line": [{"x": 0, "y": 100}, {"x": 0, "y": 0}], + } + ], + predictions=[ + { + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "line": [{"x": 5, "y": 95}, {"x": 0, "y": 0}], + } + ], + expected=0.9496975567603978, + ) @pytest.fixture def point_pair(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "point": { - 'x': 0, - 'y': 0 - } - }], - predictions=[{ - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "point": { - 'x': 5, - 'y': 5 - } - }], - expected=0.879113232477017) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "point": {"x": 0, "y": 0}, + } + ], + predictions=[ + { + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "point": {"x": 5, "y": 5}, + } + ], + expected=0.879113232477017, + ) @pytest.fixture def matching_ner(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'format': "text.location", - 'data': { - "location": { - "start": 0, - "end": 10 - } - } - }], - predictions=[{ - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "location": { - "start": 0, - "end": 10 - } - }], - expected=1) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "format": "text.location", + "data": {"location": {"start": 0, 
"end": 10}}, + } + ], + predictions=[ + { + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "location": {"start": 0, "end": 10}, + } + ], + expected=1, + ) @pytest.fixture def no_matching_ner(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'format': "text.location", - 'data': { - "location": { - "start": 0, - "end": 5 - } - } - }], - predictions=[{ - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "location": { - "start": 5, - "end": 10 - } - }], - expected=0) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "format": "text.location", + "data": {"location": {"start": 0, "end": 5}}, + } + ], + predictions=[ + { + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "location": {"start": 5, "end": 10}, + } + ], + expected=0, + ) @pytest.fixture def partial_matching_ner(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'format': "text.location", - 'data': { - "location": { - "start": 0, - "end": 7 - } - } - }], - predictions=[{ - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "location": { - "start": 3, - "end": 5 - } - }], - expected=0.2857142857142857) + return NameSpace( + labels=[ + { + "featureId": "ckppivl7p0006aeyj92cezr9d", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "format": "text.location", + "data": {"location": {"start": 0, "end": 7}}, + } + ], + predictions=[ + { + "dataRow": {"id": "ckppihxc10005aeyjen11h7jh"}, + "uuid": "76e0dcea-fe46-43e5-95f5-a5e3f378520a", + "schemaId": "ckppid25v0000aeyjmxfwlc7t", + "location": {"start": 3, "end": 5}, + } + ], + expected=0.2857142857142857, + ) diff --git a/libs/labelbox/tests/data/metrics/iou/feature/conftest.py b/libs/labelbox/tests/data/metrics/iou/feature/conftest.py index c89d30056..c3b2a28e3 100644 --- a/libs/labelbox/tests/data/metrics/iou/feature/conftest.py +++ b/libs/labelbox/tests/data/metrics/iou/feature/conftest.py @@ -2,107 +2,140 @@ import pytest -from labelbox.data.annotation_types import ClassificationAnnotation, ObjectAnnotation +from labelbox.data.annotation_types import ( + ClassificationAnnotation, + ObjectAnnotation, +) from labelbox.data.annotation_types import Polygon, Point class NameSpace(SimpleNamespace): - def __init__(self, predictions, ground_truths, expected): - super(NameSpace, self).__init__(predictions=predictions, - ground_truths=ground_truths, - expected=expected) + super(NameSpace, self).__init__( + predictions=predictions, + ground_truths=ground_truths, + expected=expected, + ) @pytest.fixture def different_classes(): return [ - NameSpace(predictions=[ - ObjectAnnotation(name="cat", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - ground_truths=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - expected={ - 'cat': 0, - 'dog': 0 - }) + NameSpace( + predictions=[ + ObjectAnnotation( + name="cat", + value=Polygon( + points=[ + Point(x=0, y=0), + 
Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ) + ], + ground_truths=[ + ObjectAnnotation( + name="dog", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ) + ], + expected={"cat": 0, "dog": 0}, + ) ] @pytest.fixture def one_overlap_class(): return [ - NameSpace(predictions=[ - ObjectAnnotation(name="cat", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])), - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=5, y=0), - Point(x=5, y=5), - Point(x=0, y=5) - ])) - ], - ground_truths=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - expected={ - 'dog': 0.25, - 'cat': 0. - }), - NameSpace(predictions=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=5, y=0), - Point(x=5, y=5), - Point(x=0, y=5) - ])) - ], - ground_truths=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])), - ObjectAnnotation(name="cat", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - expected={ - 'dog': 0.25, - 'cat': 0. - }) + NameSpace( + predictions=[ + ObjectAnnotation( + name="cat", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ), + ObjectAnnotation( + name="dog", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=5, y=0), + Point(x=5, y=5), + Point(x=0, y=5), + ] + ), + ), + ], + ground_truths=[ + ObjectAnnotation( + name="dog", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ) + ], + expected={"dog": 0.25, "cat": 0.0}, + ), + NameSpace( + predictions=[ + ObjectAnnotation( + name="dog", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=5, y=0), + Point(x=5, y=5), + Point(x=0, y=5), + ] + ), + ) + ], + ground_truths=[ + ObjectAnnotation( + name="dog", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ), + ObjectAnnotation( + name="cat", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ), + ], + expected={"dog": 0.25, "cat": 0.0}, + ), ] @@ -110,46 +143,60 @@ def one_overlap_class(): def empty_annotations(): return [ NameSpace(predictions=[], ground_truths=[], expected={}), - NameSpace(predictions=[], - ground_truths=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])), - ObjectAnnotation(name="cat", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - expected={ - 'dog': 0., - 'cat': 0. - }), - NameSpace(predictions=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])), - ObjectAnnotation(name="cat", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - ground_truths=[], - expected={ - 'dog': 0., - 'cat': 0. 
- }) + NameSpace( + predictions=[], + ground_truths=[ + ObjectAnnotation( + name="dog", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ), + ObjectAnnotation( + name="cat", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ), + ], + expected={"dog": 0.0, "cat": 0.0}, + ), + NameSpace( + predictions=[ + ObjectAnnotation( + name="dog", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ), + ObjectAnnotation( + name="cat", + value=Polygon( + points=[ + Point(x=0, y=0), + Point(x=10, y=0), + Point(x=10, y=10), + Point(x=0, y=10), + ] + ), + ), + ], + ground_truths=[], + expected={"dog": 0.0, "cat": 0.0}, + ), ] diff --git a/libs/labelbox/tests/data/metrics/iou/feature/test_feature_iou.py b/libs/labelbox/tests/data/metrics/iou/feature/test_feature_iou.py index 653e485d1..da324c51b 100644 --- a/libs/labelbox/tests/data/metrics/iou/feature/test_feature_iou.py +++ b/libs/labelbox/tests/data/metrics/iou/feature/test_feature_iou.py @@ -19,7 +19,8 @@ def check_iou(pair): assert len(one_metrics) one_metric = one_metrics[0] assert one_metric.value == sum(list(pair.expected.values())) / len( - pair.expected) + pair.expected + ) def test_different_classes(different_classes): diff --git a/libs/labelbox/tests/data/serialization/coco/test_coco.py b/libs/labelbox/tests/data/serialization/coco/test_coco.py index 0113b555d..a7c733ce5 100644 --- a/libs/labelbox/tests/data/serialization/coco/test_coco.py +++ b/libs/labelbox/tests/data/serialization/coco/test_coco.py @@ -7,9 +7,10 @@ def run_instances(tmpdir): - instance_json = json.load(open(Path(COCO_ASSETS_DIR, 'instances.json'))) - res = COCOConverter.deserialize_instances(instance_json, - Path(COCO_ASSETS_DIR, 'images')) + instance_json = json.load(open(Path(COCO_ASSETS_DIR, "instances.json"))) + res = COCOConverter.deserialize_instances( + instance_json, Path(COCO_ASSETS_DIR, "images") + ) back = COCOConverter.serialize_instances( res, Path(tmpdir), @@ -17,18 +18,21 @@ def run_instances(tmpdir): def test_rle_objects(tmpdir): - rle_json = json.load(open(Path(COCO_ASSETS_DIR, 'rle.json'))) - res = COCOConverter.deserialize_instances(rle_json, - Path(COCO_ASSETS_DIR, 'images')) + rle_json = json.load(open(Path(COCO_ASSETS_DIR, "rle.json"))) + res = COCOConverter.deserialize_instances( + rle_json, Path(COCO_ASSETS_DIR, "images") + ) back = COCOConverter.serialize_instances(res, tmpdir) def test_panoptic(tmpdir): - panoptic_json = json.load(open(Path(COCO_ASSETS_DIR, 'panoptic.json'))) + panoptic_json = json.load(open(Path(COCO_ASSETS_DIR, "panoptic.json"))) image_dir, mask_dir = [ - Path(COCO_ASSETS_DIR, dir_name) for dir_name in ['images', 'masks'] + Path(COCO_ASSETS_DIR, dir_name) for dir_name in ["images", "masks"] ] res = COCOConverter.deserialize_panoptic(panoptic_json, image_dir, mask_dir) - back = COCOConverter.serialize_panoptic(res, - Path(f'/{tmpdir}/images_panoptic'), - Path(f'/{tmpdir}/masks_panoptic')) + back = COCOConverter.serialize_panoptic( + res, + Path(f"/{tmpdir}/images_panoptic"), + Path(f"/{tmpdir}/masks_panoptic"), + ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py b/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py index c4b47427a..0bc3c8924 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py @@ 
-1,5 +1,9 @@ from labelbox.data.annotation_types.annotation import ClassificationAnnotation -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnswer, Radio +from labelbox.data.annotation_types.classification.classification import ( + Checklist, + ClassificationAnswer, + Radio, +) from labelbox.data.annotation_types.data.text import TextData from labelbox.data.annotation_types.label import Label @@ -17,18 +21,16 @@ def test_serialization_min(): ClassificationAnnotation( name="checkbox_question_geo", value=Checklist( - answer=[ClassificationAnswer(name="first_answer")]), + answer=[ClassificationAnswer(name="first_answer")] + ), ) - ]) + ], + ) expected = { - 'name': 'checkbox_question_geo', - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - }, - 'answer': [{ - 'name': 'first_answer' - }] + "name": "checkbox_question_geo", + "dataRow": {"id": "bkj7z2q0b0000jx6x0q2q7q0d"}, + "answer": [{"name": "first_answer"}], } serialized = NDJsonConverter.serialize([label]) res = next(serialized) @@ -54,61 +56,76 @@ def test_serialization_with_classification(): ClassificationAnnotation( name="checkbox_question_geo", confidence=0.5, - value=Checklist(answer=[ - ClassificationAnswer( - name="first_answer", - confidence=0.1, - classifications=[ - ClassificationAnnotation( - name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer", - confidence=0.31))), - ClassificationAnnotation( - name="sub_chck_question", - value=Checklist(answer=[ - ClassificationAnswer( - name="second_subchk_answer", - confidence=0.41), - ClassificationAnswer( - name="third_subchk_answer", - confidence=0.42), - ],)) - ]), - ])) - ]) + value=Checklist( + answer=[ + ClassificationAnswer( + name="first_answer", + confidence=0.1, + classifications=[ + ClassificationAnnotation( + name="sub_radio_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_radio_answer", + confidence=0.31, + ) + ), + ), + ClassificationAnnotation( + name="sub_chck_question", + value=Checklist( + answer=[ + ClassificationAnswer( + name="second_subchk_answer", + confidence=0.41, + ), + ClassificationAnswer( + name="third_subchk_answer", + confidence=0.42, + ), + ], + ), + ), + ], + ), + ] + ), + ) + ], + ) expected = { - 'confidence': - 0.5, - 'name': - 'checkbox_question_geo', - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - }, - 'answer': [{ - 'confidence': - 0.1, - 'name': - 'first_answer', - 'classifications': [{ - 'name': 'sub_radio_question', - 'answer': { - 'confidence': 0.31, - 'name': 'first_sub_radio_answer', - } - }, { - 'name': - 'sub_chck_question', - 'answer': [{ - 'confidence': 0.41, - 'name': 'second_subchk_answer', - }, { - 'confidence': 0.42, - 'name': 'third_subchk_answer', - }] - }] - }] + "confidence": 0.5, + "name": "checkbox_question_geo", + "dataRow": {"id": "bkj7z2q0b0000jx6x0q2q7q0d"}, + "answer": [ + { + "confidence": 0.1, + "name": "first_answer", + "classifications": [ + { + "name": "sub_radio_question", + "answer": { + "confidence": 0.31, + "name": "first_sub_radio_answer", + }, + }, + { + "name": "sub_chck_question", + "answer": [ + { + "confidence": 0.41, + "name": "second_subchk_answer", + }, + { + "confidence": 0.42, + "name": "third_subchk_answer", + }, + ], + }, + ], + } + ], } serialized = NDJsonConverter.serialize([label]) @@ -119,7 +136,9 @@ def test_serialization_with_classification(): deserialized = NDJsonConverter.deserialize([res]) res = next(deserialized) - assert label.model_dump(exclude_none=True) == 
label.model_dump(exclude_none=True) + assert label.model_dump(exclude_none=True) == label.model_dump( + exclude_none=True + ) def test_serialization_with_classification_double_nested(): @@ -133,66 +152,80 @@ def test_serialization_with_classification_double_nested(): ClassificationAnnotation( name="checkbox_question_geo", confidence=0.5, - value=Checklist(answer=[ - ClassificationAnswer( - name="first_answer", - confidence=0.1, - classifications=[ - ClassificationAnnotation( - name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer", - confidence=0.31, - classifications=[ - ClassificationAnnotation( - name="sub_chck_question", - value=Checklist(answer=[ - ClassificationAnswer( - name="second_subchk_answer", - confidence=0.41), - ClassificationAnswer( - name="third_subchk_answer", - confidence=0.42), - ],)) - ]))), - ]), - ])) - ]) + value=Checklist( + answer=[ + ClassificationAnswer( + name="first_answer", + confidence=0.1, + classifications=[ + ClassificationAnnotation( + name="sub_radio_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_radio_answer", + confidence=0.31, + classifications=[ + ClassificationAnnotation( + name="sub_chck_question", + value=Checklist( + answer=[ + ClassificationAnswer( + name="second_subchk_answer", + confidence=0.41, + ), + ClassificationAnswer( + name="third_subchk_answer", + confidence=0.42, + ), + ], + ), + ) + ], + ) + ), + ), + ], + ), + ] + ), + ) + ], + ) expected = { - 'confidence': - 0.5, - 'name': - 'checkbox_question_geo', - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - }, - 'answer': [{ - 'confidence': - 0.1, - 'name': - 'first_answer', - 'classifications': [{ - 'name': 'sub_radio_question', - 'answer': { - 'confidence': - 0.31, - 'name': - 'first_sub_radio_answer', - 'classifications': [{ - 'name': - 'sub_chck_question', - 'answer': [{ - 'confidence': 0.41, - 'name': 'second_subchk_answer', - }, { - 'confidence': 0.42, - 'name': 'third_subchk_answer', - }] - }] - } - }] - }] + "confidence": 0.5, + "name": "checkbox_question_geo", + "dataRow": {"id": "bkj7z2q0b0000jx6x0q2q7q0d"}, + "answer": [ + { + "confidence": 0.1, + "name": "first_answer", + "classifications": [ + { + "name": "sub_radio_question", + "answer": { + "confidence": 0.31, + "name": "first_sub_radio_answer", + "classifications": [ + { + "name": "sub_chck_question", + "answer": [ + { + "confidence": 0.41, + "name": "second_subchk_answer", + }, + { + "confidence": 0.42, + "name": "third_subchk_answer", + }, + ], + } + ], + }, + } + ], + } + ], } serialized = NDJsonConverter.serialize([label]) res = next(serialized) @@ -203,7 +236,9 @@ def test_serialization_with_classification_double_nested(): deserialized = NDJsonConverter.deserialize([res]) res = next(deserialized) res.annotations[0].extra.pop("uuid") - assert label.model_dump(exclude_none=True) == label.model_dump(exclude_none=True) + assert label.model_dump(exclude_none=True) == label.model_dump( + exclude_none=True + ) def test_serialization_with_classification_double_nested_2(): @@ -216,62 +251,79 @@ def test_serialization_with_classification_double_nested_2(): annotations=[ ClassificationAnnotation( name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer", - confidence=0.31, - classifications=[ - ClassificationAnnotation( - name="sub_chck_question", - value=Checklist(answer=[ - ClassificationAnswer( - name="second_subchk_answer", - confidence=0.41, - classifications=[ - ClassificationAnnotation( - 
name="checkbox_question_geo", - value=Checklist(answer=[ - ClassificationAnswer( - name="first_answer", - confidence=0.1), - ])) - ]), - ClassificationAnswer(name="third_subchk_answer", - confidence=0.42), - ])) - ]))), - ]) + value=Radio( + answer=ClassificationAnswer( + name="first_sub_radio_answer", + confidence=0.31, + classifications=[ + ClassificationAnnotation( + name="sub_chck_question", + value=Checklist( + answer=[ + ClassificationAnswer( + name="second_subchk_answer", + confidence=0.41, + classifications=[ + ClassificationAnnotation( + name="checkbox_question_geo", + value=Checklist( + answer=[ + ClassificationAnswer( + name="first_answer", + confidence=0.1, + ), + ] + ), + ) + ], + ), + ClassificationAnswer( + name="third_subchk_answer", + confidence=0.42, + ), + ] + ), + ) + ], + ) + ), + ), + ], + ) expected = { - 'name': 'sub_radio_question', - 'answer': { - 'confidence': - 0.31, - 'name': - 'first_sub_radio_answer', - 'classifications': [{ - 'name': - 'sub_chck_question', - 'answer': [{ - 'confidence': - 0.41, - 'name': - 'second_subchk_answer', - 'classifications': [{ - 'name': 'checkbox_question_geo', - 'answer': [{ - 'confidence': 0.1, - 'name': 'first_answer', - }] - }] - }, { - 'confidence': 0.42, - 'name': 'third_subchk_answer', - }] - }] + "name": "sub_radio_question", + "answer": { + "confidence": 0.31, + "name": "first_sub_radio_answer", + "classifications": [ + { + "name": "sub_chck_question", + "answer": [ + { + "confidence": 0.41, + "name": "second_subchk_answer", + "classifications": [ + { + "name": "checkbox_question_geo", + "answer": [ + { + "confidence": 0.1, + "name": "first_answer", + } + ], + } + ], + }, + { + "confidence": 0.42, + "name": "third_subchk_answer", + }, + ], + } + ], }, - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - } + "dataRow": {"id": "bkj7z2q0b0000jx6x0q2q7q0d"}, } serialized = NDJsonConverter.serialize([label]) @@ -281,4 +333,6 @@ def test_serialization_with_classification_double_nested_2(): deserialized = NDJsonConverter.deserialize([res]) res = next(deserialized) - assert label.model_dump(exclude_none=True) == label.model_dump(exclude_none=True) + assert label.model_dump(exclude_none=True) == label.model_dump( + exclude_none=True + ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_classification.py b/libs/labelbox/tests/data/serialization/ndjson/test_classification.py index 00a684b20..8dcb17f0b 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_classification.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_classification.py @@ -4,8 +4,9 @@ def test_classification(): - with open('tests/data/assets/ndjson/classification_import.json', - 'r') as file: + with open( + "tests/data/assets/ndjson/classification_import.json", "r" + ) as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) @@ -13,8 +14,9 @@ def test_classification(): def test_classification_with_name(): - with open('tests/data/assets/ndjson/classification_import_name_only.json', - 'r') as file: + with open( + "tests/data/assets/ndjson/classification_import_name_only.json", "r" + ) as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py b/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py index 4d2a0416c..f7da9181b 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py +++ 
b/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py @@ -4,99 +4,117 @@ import labelbox.types as lb_types from labelbox.data.serialization.ndjson.converter import NDJsonConverter -radio_ndjson = [{ - 'dataRow': { - 'globalKey': 'my_global_key' - }, - 'name': 'radio', - 'answer': { - 'name': 'first_radio_answer' - }, - 'messageId': '0' -}] +radio_ndjson = [ + { + "dataRow": {"globalKey": "my_global_key"}, + "name": "radio", + "answer": {"name": "first_radio_answer"}, + "messageId": "0", + } +] radio_label = [ lb_types.Label( - data=lb_types.ConversationData(global_key='my_global_key'), + data=lb_types.ConversationData(global_key="my_global_key"), annotations=[ lb_types.ClassificationAnnotation( - name='radio', - value=lb_types.Radio(answer=lb_types.ClassificationAnswer( - name="first_radio_answer")), - message_id="0") - ]) + name="radio", + value=lb_types.Radio( + answer=lb_types.ClassificationAnswer( + name="first_radio_answer" + ) + ), + message_id="0", + ) + ], + ) ] -checklist_ndjson = [{ - 'dataRow': { - 'globalKey': 'my_global_key' - }, - 'name': 'checklist', - 'answer': [ - { - 'name': 'first_checklist_answer' - }, - { - 'name': 'second_checklist_answer' - }, - ], - 'messageId': '2' -}] +checklist_ndjson = [ + { + "dataRow": {"globalKey": "my_global_key"}, + "name": "checklist", + "answer": [ + {"name": "first_checklist_answer"}, + {"name": "second_checklist_answer"}, + ], + "messageId": "2", + } +] checklist_label = [ - lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'), - annotations=[ - lb_types.ClassificationAnnotation( - name='checklist', - message_id="2", - value=lb_types.Checklist(answer=[ - lb_types.ClassificationAnswer( - name="first_checklist_answer"), - lb_types.ClassificationAnswer( - name="second_checklist_answer") - ])) - ]) + lb_types.Label( + data=lb_types.ConversationData(global_key="my_global_key"), + annotations=[ + lb_types.ClassificationAnnotation( + name="checklist", + message_id="2", + value=lb_types.Checklist( + answer=[ + lb_types.ClassificationAnswer( + name="first_checklist_answer" + ), + lb_types.ClassificationAnswer( + name="second_checklist_answer" + ), + ] + ), + ) + ], + ) ] -free_text_ndjson = [{ - 'dataRow': { - 'globalKey': 'my_global_key' - }, - 'name': 'free_text', - 'answer': 'sample text', - 'messageId': '0' -}] +free_text_ndjson = [ + { + "dataRow": {"globalKey": "my_global_key"}, + "name": "free_text", + "answer": "sample text", + "messageId": "0", + } +] free_text_label = [ - lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'), - annotations=[ - lb_types.ClassificationAnnotation( - name='free_text', - message_id="0", - value=lb_types.Text(answer="sample text")) - ]) + lb_types.Label( + data=lb_types.ConversationData(global_key="my_global_key"), + annotations=[ + lb_types.ClassificationAnnotation( + name="free_text", + message_id="0", + value=lb_types.Text(answer="sample text"), + ) + ], + ) ] @pytest.mark.parametrize( "label, ndjson", - [[radio_label, radio_ndjson], [checklist_label, checklist_ndjson], - [free_text_label, free_text_ndjson]]) + [ + [radio_label, radio_ndjson], + [checklist_label, checklist_ndjson], + [free_text_label, free_text_ndjson], + ], +) def test_message_based_radio_classification(label, ndjson): serialized_label = list(NDJsonConverter().serialize(label)) - serialized_label[0].pop('uuid') + serialized_label[0].pop("uuid") assert serialized_label == ndjson deserialized_label = list(NDJsonConverter().deserialize(ndjson)) - 
deserialized_label[0].annotations[0].extra.pop('uuid') - assert deserialized_label[0].model_dump(exclude_none=True) == label[0].model_dump(exclude_none=True) + deserialized_label[0].annotations[0].extra.pop("uuid") + assert deserialized_label[0].model_dump(exclude_none=True) == label[ + 0 + ].model_dump(exclude_none=True) -@pytest.mark.parametrize("filename", [ - "tests/data/assets/ndjson/conversation_entity_import.json", - "tests/data/assets/ndjson/conversation_entity_without_confidence_import.json" -]) +@pytest.mark.parametrize( + "filename", + [ + "tests/data/assets/ndjson/conversation_entity_import.json", + "tests/data/assets/ndjson/conversation_entity_without_confidence_import.json", + ], +) def test_conversation_entity_import(filename: str): - with open(filename, 'r') as file: + with open(filename, "r") as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) @@ -104,30 +122,34 @@ def test_conversation_entity_import(filename: str): def test_benchmark_reference_label_flag_enabled(): - label = lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'), - annotations=[ - lb_types.ClassificationAnnotation( - name='free_text', - message_id="0", - value=lb_types.Text(answer="sample text")) - ], - is_benchmark_reference=True - ) + label = lb_types.Label( + data=lb_types.ConversationData(global_key="my_global_key"), + annotations=[ + lb_types.ClassificationAnnotation( + name="free_text", + message_id="0", + value=lb_types.Text(answer="sample text"), + ) + ], + is_benchmark_reference=True, + ) res = list(NDJsonConverter.serialize([label])) assert res[0]["isBenchmarkReferenceLabel"] def test_benchmark_reference_label_flag_disabled(): - label = lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'), - annotations=[ - lb_types.ClassificationAnnotation( - name='free_text', - message_id="0", - value=lb_types.Text(answer="sample text")) - ], - is_benchmark_reference=False - ) + label = lb_types.Label( + data=lb_types.ConversationData(global_key="my_global_key"), + annotations=[ + lb_types.ClassificationAnnotation( + name="free_text", + message_id="0", + value=lb_types.Text(answer="sample text"), + ) + ], + is_benchmark_reference=False, + ) res = list(NDJsonConverter.serialize([label])) assert not res[0].get("isBenchmarkReferenceLabel") diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py b/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py index 186c75223..333c00250 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py @@ -2,35 +2,41 @@ import pytest import labelbox.types as lb_types from labelbox.data.serialization import NDJsonConverter -from labelbox.data.serialization.ndjson.objects import NDDicomSegments, NDDicomSegment, NDDicomLine +from labelbox.data.serialization.ndjson.objects import ( + NDDicomSegments, + NDDicomSegment, + NDDicomLine, +) + """ Data gen prompt test data """ prompt_text_annotation = lb_types.PromptClassificationAnnotation( - feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", - name="test", - value=lb_types.PromptText(answer="the answer to the text questions right here"), - ) - -prompt_text_ndjson = { - "answer": "the answer to the text questions right here", - "name": "test", - "schemaId": "ckrb1sfkn099c0y910wbo0p1a", - "dataRow": { - "id": "ckrb1sf1i1g7i0ybcdc6oc8ct" - }, - } + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + name="test", + 
value=lb_types.PromptText( + answer="the answer to the text questions right here" + ), +) + +prompt_text_ndjson = { + "answer": "the answer to the text questions right here", + "name": "test", + "schemaId": "ckrb1sfkn099c0y910wbo0p1a", + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, +} data_gen_label = lb_types.Label( data={"uid": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, - annotations=[prompt_text_annotation] + annotations=[prompt_text_annotation], ) - + """ Prompt annotation test """ + def test_serialize_label(): serialized_label = next(NDJsonConverter().serialize([data_gen_label])) # Remove uuid field since this is a random value that can not be specified also meant for relationships @@ -39,17 +45,23 @@ def test_serialize_label(): def test_deserialize_label(): - deserialized_label = next(NDJsonConverter().deserialize([prompt_text_ndjson])) - if hasattr(deserialized_label.annotations[0], 'extra'): + deserialized_label = next( + NDJsonConverter().deserialize([prompt_text_ndjson]) + ) + if hasattr(deserialized_label.annotations[0], "extra"): # Extra fields are added to deserialized label by default need removed to match deserialized_label.annotations[0].extra = {} - assert deserialized_label.model_dump(exclude_none=True) == data_gen_label.model_dump(exclude_none=True) + assert deserialized_label.model_dump( + exclude_none=True + ) == data_gen_label.model_dump(exclude_none=True) def test_serialize_deserialize_label(): serialized = list(NDJsonConverter.serialize([data_gen_label])) deserialized = next(NDJsonConverter.deserialize(serialized)) - if hasattr(deserialized.annotations[0], 'extra'): + if hasattr(deserialized.annotations[0], "extra"): # Extra fields are added to deserialized label by default need removed to match deserialized.annotations[0].extra = {} - assert deserialized.model_dump(exclude_none=True) == data_gen_label.model_dump(exclude_none=True) + assert deserialized.model_dump( + exclude_none=True + ) == data_gen_label.model_dump(exclude_none=True) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py b/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py index e69c21bae..633214367 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py @@ -3,127 +3,120 @@ import base64 import labelbox.types as lb_types from labelbox.data.serialization import NDJsonConverter -from labelbox.data.serialization.ndjson.objects import NDDicomSegments, NDDicomSegment, NDDicomLine +from labelbox.data.serialization.ndjson.objects import ( + NDDicomSegments, + NDDicomSegment, + NDDicomLine, +) + """ Polyline test data """ dicom_polyline_annotations = [ - lb_types.DICOMObjectAnnotation(uuid="78a8a027-9089-420c-8348-6099eb77e4aa", - name="dicom_polyline", - frame=2, - value=lb_types.Line(points=[ - lb_types.Point(x=680, y=100), - lb_types.Point(x=100, y=190), - lb_types.Point(x=190, y=220) - ]), - segment_index=0, - keyframe=True, - group_key=lb_types.GroupKey.AXIAL) + lb_types.DICOMObjectAnnotation( + uuid="78a8a027-9089-420c-8348-6099eb77e4aa", + name="dicom_polyline", + frame=2, + value=lb_types.Line( + points=[ + lb_types.Point(x=680, y=100), + lb_types.Point(x=100, y=190), + lb_types.Point(x=190, y=220), + ] + ), + segment_index=0, + keyframe=True, + group_key=lb_types.GroupKey.AXIAL, + ) ] -polyline_label = lb_types.Label(data=lb_types.DicomData(uid="test-uid"), - annotations=dicom_polyline_annotations) +polyline_label = lb_types.Label( + data=lb_types.DicomData(uid="test-uid"), + 
annotations=dicom_polyline_annotations, +) polyline_annotation_ndjson = { - 'classifications': [], - 'dataRow': { - 'id': 'test-uid' - }, - 'name': - 'dicom_polyline', - 'groupKey': - 'axial', - 'segments': [{ - 'keyframes': [{ - 'frame': 2, - 'line': [ - { - 'x': 680.0, - 'y': 100.0 - }, + "classifications": [], + "dataRow": {"id": "test-uid"}, + "name": "dicom_polyline", + "groupKey": "axial", + "segments": [ + { + "keyframes": [ { - 'x': 100.0, - 'y': 190.0 - }, - { - 'x': 190.0, - 'y': 220.0 - }, - ], - 'classifications': [], - }] - }], + "frame": 2, + "line": [ + {"x": 680.0, "y": 100.0}, + {"x": 100.0, "y": 190.0}, + {"x": 190.0, "y": 220.0}, + ], + "classifications": [], + } + ] + } + ], } polyline_with_global_key = lb_types.Label( data=lb_types.DicomData(global_key="test-global-key"), - annotations=dicom_polyline_annotations) + annotations=dicom_polyline_annotations, +) polyline_annotation_ndjson_with_global_key = copy(polyline_annotation_ndjson) -polyline_annotation_ndjson_with_global_key['dataRow'] = { - 'globalKey': 'test-global-key' +polyline_annotation_ndjson_with_global_key["dataRow"] = { + "globalKey": "test-global-key" } """ Video test data """ -instance_uri_1 = 'https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA' -instance_uri_5 = 'https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA' +instance_uri_1 = "https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA" +instance_uri_5 = "https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA" frames = [ lb_types.MaskFrame(index=1, instance_uri=instance_uri_1), - lb_types.MaskFrame(index=5, instance_uri=instance_uri_5) + lb_types.MaskFrame(index=5, instance_uri=instance_uri_5), ] instances = [ lb_types.MaskInstance(color_rgb=(0, 0, 255), name="mask1"), lb_types.MaskInstance(color_rgb=(0, 255, 0), name="mask2"), - lb_types.MaskInstance(color_rgb=(255, 0, 0), name="mask3") + lb_types.MaskInstance(color_rgb=(255, 0, 0), name="mask3"), ] -video_mask_annotation = lb_types.VideoMaskAnnotation(frames=frames, - instances=instances) +video_mask_annotation = lb_types.VideoMaskAnnotation( + frames=frames, instances=instances +) video_mask_annotation_ndjson = { - 'dataRow': { - 'id': 'test-uid' - }, - 'masks': { - 'frames': [{ - 'index': 1, - 'instanceURI': instance_uri_1 - }, { - 'index': 5, - 'instanceURI': instance_uri_5 - }], - 'instances': [ - { - 'colorRGB': (0, 0, 255), - 'name': 'mask1' - }, - { - 'colorRGB': (0, 255, 0), - 'name': 'mask2' - }, - { - 'colorRGB': (255, 0, 0), - 'name': 'mask3' - }, - ] + "dataRow": {"id": "test-uid"}, + "masks": { + "frames": [ + {"index": 1, "instanceURI": instance_uri_1}, + {"index": 5, "instanceURI": instance_uri_5}, + ], + "instances": [ + {"colorRGB": (0, 0, 255), "name": "mask1"}, + {"colorRGB": (0, 255, 0), "name": "mask2"}, + {"colorRGB": (255, 0, 0), "name": "mask3"}, + ], }, } video_mask_annotation_ndjson_with_global_key = copy( - video_mask_annotation_ndjson) -video_mask_annotation_ndjson_with_global_key['dataRow'] = { - 'globalKey': 'test-global-key' + 
video_mask_annotation_ndjson +) +video_mask_annotation_ndjson_with_global_key["dataRow"] = { + "globalKey": "test-global-key" } -video_mask_label = lb_types.Label(data=lb_types.VideoData(uid="test-uid"), - annotations=[video_mask_annotation]) +video_mask_label = lb_types.Label( + data=lb_types.VideoData(uid="test-uid"), annotations=[video_mask_annotation] +) video_mask_label_with_global_key = lb_types.Label( data=lb_types.VideoData(global_key="test-global-key"), - annotations=[video_mask_annotation]) + annotations=[video_mask_annotation], +) """ DICOM Mask test data """ @@ -132,30 +125,37 @@ name="dicom_mask", group_key=lb_types.GroupKey.AXIAL, frames=frames, - instances=instances) + instances=instances, +) -dicom_mask_label = lb_types.Label(data=lb_types.DicomData(uid="test-uid"), - annotations=[dicom_mask_annotation]) +dicom_mask_label = lb_types.Label( + data=lb_types.DicomData(uid="test-uid"), annotations=[dicom_mask_annotation] +) dicom_mask_label_with_global_key = lb_types.Label( data=lb_types.DicomData(global_key="test-global-key"), - annotations=[dicom_mask_annotation]) + annotations=[dicom_mask_annotation], +) dicom_mask_annotation_ndjson = copy(video_mask_annotation_ndjson) -dicom_mask_annotation_ndjson['groupKey'] = 'axial' +dicom_mask_annotation_ndjson["groupKey"] = "axial" dicom_mask_annotation_ndjson_with_global_key = copy( - dicom_mask_annotation_ndjson) -dicom_mask_annotation_ndjson_with_global_key['dataRow'] = { - 'globalKey': 'test-global-key' + dicom_mask_annotation_ndjson +) +dicom_mask_annotation_ndjson_with_global_key["dataRow"] = { + "globalKey": "test-global-key" } """ Tests """ labels = [ - polyline_label, polyline_with_global_key, dicom_mask_label, - dicom_mask_label_with_global_key, video_mask_label, - video_mask_label_with_global_key + polyline_label, + polyline_with_global_key, + dicom_mask_label, + dicom_mask_label_with_global_key, + video_mask_label, + video_mask_label_with_global_key, ] ndjsons = [ polyline_annotation_ndjson, @@ -175,32 +175,31 @@ def test_deserialize_nd_dicom_segments(): assert isinstance(nd_dicom_segments.segments[0].keyframes[0], NDDicomLine) -@pytest.mark.parametrize('label, ndjson', labels_ndjsons) +@pytest.mark.parametrize("label, ndjson", labels_ndjsons) def test_serialize_label(label, ndjson): serialized_label = next(NDJsonConverter().serialize([label])) if "uuid" in serialized_label: - serialized_label.pop('uuid') + serialized_label.pop("uuid") assert serialized_label == ndjson -@pytest.mark.parametrize('label, ndjson', labels_ndjsons) +@pytest.mark.parametrize("label, ndjson", labels_ndjsons) def test_deserialize_label(label, ndjson): deserialized_label = next(NDJsonConverter().deserialize([ndjson])) - if hasattr(deserialized_label.annotations[0], 'extra'): + if hasattr(deserialized_label.annotations[0], "extra"): deserialized_label.annotations[0].extra = {} for i, annotation in enumerate(deserialized_label.annotations): if hasattr(annotation, "frames"): assert annotation.frames == label.annotations[i].frames if hasattr(annotation, "value"): assert annotation.value == label.annotations[i].value - -@pytest.mark.parametrize('label', labels) +@pytest.mark.parametrize("label", labels) def test_serialize_deserialize_label(label): serialized = list(NDJsonConverter.serialize([label])) deserialized = list(NDJsonConverter.deserialize(serialized)) - if hasattr(deserialized[0].annotations[0], 'extra'): + if hasattr(deserialized[0].annotations[0], "extra"): deserialized[0].annotations[0].extra = {} for i, annotation in 
enumerate(deserialized[0].annotations): if hasattr(annotation, "frames"): diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_document.py b/libs/labelbox/tests/data/serialization/ndjson/test_document.py index cdfbbbb88..5fe6a9789 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_document.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_document.py @@ -8,26 +8,30 @@ start=lb_types.Point(x=42.799, y=86.498), # Top left end=lb_types.Point(x=141.911, y=303.195), # Bottom right page=1, - unit=lb_types.RectangleUnit.POINTS)) + unit=lb_types.RectangleUnit.POINTS, + ), +) bbox_labels = [ - lb_types.Label(data=lb_types.DocumentData(global_key='test-global-key'), - annotations=[bbox_annotation]) + lb_types.Label( + data=lb_types.DocumentData(global_key="test-global-key"), + annotations=[bbox_annotation], + ) +] +bbox_ndjson = [ + { + "bbox": { + "height": 216.697, + "left": 42.799, + "top": 86.498, + "width": 99.112, + }, + "classifications": [], + "dataRow": {"globalKey": "test-global-key"}, + "name": "bounding_box", + "page": 1, + "unit": "POINTS", + } ] -bbox_ndjson = [{ - 'bbox': { - 'height': 216.697, - 'left': 42.799, - 'top': 86.498, - 'width': 99.112, - }, - 'classifications': [], - 'dataRow': { - 'globalKey': 'test-global-key' - }, - 'name': 'bounding_box', - 'page': 1, - 'unit': 'POINTS' -}] def round_dict(data): @@ -47,7 +51,7 @@ def test_pdf(): """ Tests a pdf file with bbox annotations only """ - with open('tests/data/assets/ndjson/pdf_import.json', 'r') as f: + with open("tests/data/assets/ndjson/pdf_import.json", "r") as f: data = json.load(f) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) @@ -59,7 +63,7 @@ def test_pdf_with_name_only(): """ Tests a pdf file with bbox annotations only """ - with open('tests/data/assets/ndjson/pdf_import_name_only.json', 'r') as f: + with open("tests/data/assets/ndjson/pdf_import_name_only.json", "r") as f: data = json.load(f) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) @@ -69,12 +73,18 @@ def test_pdf_with_name_only(): def test_pdf_bbox_serialize(): serialized = list(NDJsonConverter.serialize(bbox_labels)) - serialized[0].pop('uuid') + serialized[0].pop("uuid") assert serialized == bbox_ndjson def test_pdf_bbox_deserialize(): deserialized = list(NDJsonConverter.deserialize(bbox_ndjson)) deserialized[0].annotations[0].extra = {} - assert deserialized[0].annotations[0].value == bbox_labels[0].annotations[0].value - assert deserialized[0].annotations[0].name == bbox_labels[0].annotations[0].name \ No newline at end of file + assert ( + deserialized[0].annotations[0].value + == bbox_labels[0].annotations[0].value + ) + assert ( + deserialized[0].annotations[0].name + == bbox_labels[0].annotations[0].name + ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py b/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py index c85b48234..4adcd9935 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py @@ -6,588 +6,580 @@ def video_bbox_label(): return Label( - uid='cl1z52xwh00050fhcmfgczqvn', + uid="cl1z52xwh00050fhcmfgczqvn", data=VideoData( uid="cklr9mr4m5iao0rb6cvxu4qbn", file_path=None, frames=None, - url= - 
"https://storage.labelbox.com/ckcz6bubudyfi0855o1dt1g9s%2F26403a22-604a-a38c-eeff-c2ed481fb40a-cat.mp4?Expires=1651677421050&KeyName=labelbox-assets-key-3&Signature=vF7gMyfHzgZdfbB8BHgd88Ws-Ms" + url="https://storage.labelbox.com/ckcz6bubudyfi0855o1dt1g9s%2F26403a22-604a-a38c-eeff-c2ed481fb40a-cat.mp4?Expires=1651677421050&KeyName=labelbox-assets-key-3&Signature=vF7gMyfHzgZdfbB8BHgd88Ws-Ms", ), annotations=[ - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=46.0), - end=Point(extra={}, - x=454.0, - y=295.0)), - classifications=[], - frame=1, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=42.5), - end=Point(extra={}, - x=427.25, - y=308.25)), - classifications=[], - frame=2, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=39.0), - end=Point(extra={}, - x=400.5, - y=321.5)), - classifications=[], - frame=3, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=35.5), - end=Point(extra={}, - x=373.75, - y=334.75)), - classifications=[], - frame=4, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=32.0), - end=Point(extra={}, - x=347.0, - y=348.0)), - classifications=[], - frame=5, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=132.0), - end=Point(extra={}, - x=283.0, - y=348.0)), - classifications=[], - frame=9, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=122.333), - end=Point(extra={}, - x=295.5, - y=348.0)), - classifications=[], - frame=10, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=112.667), - end=Point(extra={}, - x=308.0, - y=348.0)), - classifications=[], - frame=11, - keyframe=False), - 
VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=103.0), - end=Point(extra={}, - x=320.5, - y=348.0)), - classifications=[], - frame=12, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=93.333), - end=Point(extra={}, - x=333.0, - y=348.0)), - classifications=[], - frame=13, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=83.667), - end=Point(extra={}, - x=345.5, - y=348.0)), - classifications=[], - frame=14, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=74.0), - end=Point(extra={}, - x=358.0, - y=348.0)), - classifications=[], - frame=15, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=66.833), - end=Point(extra={}, - x=387.333, - y=348.0)), - classifications=[], - frame=16, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=59.667), - end=Point(extra={}, - x=416.667, - y=348.0)), - classifications=[], - frame=17, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=52.5), - end=Point(extra={}, - x=446.0, - y=348.0)), - classifications=[], - frame=18, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=45.333), - end=Point(extra={}, - x=475.333, - y=348.0)), - classifications=[], - frame=19, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=38.167), - end=Point(extra={}, - x=504.667, - y=348.0)), - classifications=[], - frame=20, - keyframe=False), - VideoObjectAnnotation(name='bbox 
toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=31.0), - end=Point(extra={}, - x=534.0, - y=348.0)), - classifications=[], - frame=21, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=29.5), - end=Point(extra={}, - x=543.0, - y=348.0)), - classifications=[], - frame=22, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=28.0), - end=Point(extra={}, - x=552.0, - y=348.0)), - classifications=[], - frame=23, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=26.5), - end=Point(extra={}, - x=561.0, - y=348.0)), - classifications=[], - frame=24, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=25.0), - end=Point(extra={}, - x=570.0, - y=348.0)), - classifications=[], - frame=25, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=23.5), - end=Point(extra={}, - x=579.0, - y=348.0)), - classifications=[], - frame=26, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=22.0), - end=Point(extra={}, - x=588.0, - y=348.0)), - classifications=[], - frame=27, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=20.5), - end=Point(extra={}, - x=597.0, - y=348.0)), - classifications=[], - frame=28, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=19.0), - end=Point(extra={}, - x=606.0, - y=348.0)), - classifications=[], - frame=29, - keyframe=True) - ], - extra={ - 'Created By': - 'jtso@labelbox.com', - 'Project Name': - 'Pictor Video', - 
'Created At': - '2022-04-14T15:11:19.000Z', - 'Updated At': - '2022-04-14T15:11:21.064Z', - 'Seconds to Label': - 0.0, - 'Agreement': - -1.0, - 'Benchmark Agreement': - -1.0, - 'Benchmark ID': - None, - 'Dataset Name': - 'cat', - 'Reviews': [], - 'View Label': - 'https://editor.labelbox.com?project=ckz38nsfd0lzq109bhq73est1&label=cl1z52xwh00050fhcmfgczqvn', - 'Has Open Issues': - 0.0, - 'Skipped': - False, - 'media_type': - 'video', - 'Data Split': - None - }) - - -def video_serialized_bbox_label(): - return { - 'uuid': - 'b24e672b-8f79-4d96-bf5e-b552ca0820d5', - 'dataRow': { - 'id': 'cklr9mr4m5iao0rb6cvxu4qbn' - }, - 'schemaId': - 'ckz38ofop0mci0z9i9w3aa9o4', - 'name': - 'bbox toy', - 'classifications': [], - 'segments': [{ - 'keyframes': [{ - 'frame': 1, - 'bbox': { - 'top': 46.0, - 'left': 70.0, - 'height': 249.0, - 'width': 384.0 + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=46.0), + end=Point(extra={}, x=454.0, y=295.0), + ), + classifications=[], + frame=1, + keyframe=True, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=42.5), + end=Point(extra={}, x=427.25, y=308.25), + ), + classifications=[], + frame=2, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=39.0), + end=Point(extra={}, x=400.5, y=321.5), + ), + classifications=[], + frame=3, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=35.5), + end=Point(extra={}, x=373.75, y=334.75), + ), + classifications=[], + frame=4, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=32.0), + end=Point(extra={}, x=347.0, y=348.0), + ), + classifications=[], + frame=5, + keyframe=True, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=132.0), + end=Point(extra={}, x=283.0, y=348.0), + ), + classifications=[], + frame=9, + keyframe=True, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=122.333), + end=Point(extra={}, x=295.5, y=348.0), 
+ ), + classifications=[], + frame=10, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=112.667), + end=Point(extra={}, x=308.0, y=348.0), + ), + classifications=[], + frame=11, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=103.0), + end=Point(extra={}, x=320.5, y=348.0), + ), + classifications=[], + frame=12, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=93.333), + end=Point(extra={}, x=333.0, y=348.0), + ), + classifications=[], + frame=13, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=83.667), + end=Point(extra={}, x=345.5, y=348.0), + ), + classifications=[], + frame=14, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", }, - 'classifications': [] - }, { - 'frame': 5, - 'bbox': { - 'top': 32.0, - 'left': 70.0, - 'height': 316.0, - 'width': 277.0 + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=74.0), + end=Point(extra={}, x=358.0, y=348.0), + ), + classifications=[], + frame=15, + keyframe=True, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", }, - 'classifications': [] - }] - }, { - 'keyframes': [{ - 'frame': 9, - 'bbox': { - 'top': 132.0, - 'left': 70.0, - 'height': 216.0, - 'width': 213.0 + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=66.833), + end=Point(extra={}, x=387.333, y=348.0), + ), + classifications=[], + frame=16, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", }, - 'classifications': [] - }, { - 'frame': 15, - 'bbox': { - 'top': 74.0, - 'left': 70.0, - 'height': 274.0, - 'width': 288.0 + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=59.667), + end=Point(extra={}, x=416.667, y=348.0), + ), + classifications=[], + frame=17, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", }, - 'classifications': [] - }, { - 'frame': 21, - 'bbox': { - 'top': 31.0, - 'left': 70.0, - 
'height': 317.0, - 'width': 464.0 + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=52.5), + end=Point(extra={}, x=446.0, y=348.0), + ), + classifications=[], + frame=18, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", }, - 'classifications': [] - }, { - 'frame': 29, - 'bbox': { - 'top': 19.0, - 'left': 70.0, - 'height': 329.0, - 'width': 536.0 + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=45.333), + end=Point(extra={}, x=475.333, y=348.0), + ), + classifications=[], + frame=19, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", }, - 'classifications': [] - }] - }] + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=38.167), + end=Point(extra={}, x=504.667, y=348.0), + ), + classifications=[], + frame=20, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=31.0), + end=Point(extra={}, x=534.0, y=348.0), + ), + classifications=[], + frame=21, + keyframe=True, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=29.5), + end=Point(extra={}, x=543.0, y=348.0), + ), + classifications=[], + frame=22, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=28.0), + end=Point(extra={}, x=552.0, y=348.0), + ), + classifications=[], + frame=23, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=26.5), + end=Point(extra={}, x=561.0, y=348.0), + ), + classifications=[], + frame=24, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=25.0), + end=Point(extra={}, x=570.0, y=348.0), + ), + classifications=[], + frame=25, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=23.5), + end=Point(extra={}, x=579.0, y=348.0), + ), + classifications=[], + frame=26, + keyframe=False, + ), + 
VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=22.0), + end=Point(extra={}, x=588.0, y=348.0), + ), + classifications=[], + frame=27, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=20.5), + end=Point(extra={}, x=597.0, y=348.0), + ), + classifications=[], + frame=28, + keyframe=False, + ), + VideoObjectAnnotation( + name="bbox toy", + feature_schema_id="ckz38ofop0mci0z9i9w3aa9o4", + extra={ + "value": "bbox_toy", + "instanceURI": None, + "color": "#1CE6FF", + "feature_id": "cl1z52xw700000fhcayaqy0ev", + }, + value=Rectangle( + extra={}, + start=Point(extra={}, x=70.0, y=19.0), + end=Point(extra={}, x=606.0, y=348.0), + ), + classifications=[], + frame=29, + keyframe=True, + ), + ], + extra={ + "Created By": "jtso@labelbox.com", + "Project Name": "Pictor Video", + "Created At": "2022-04-14T15:11:19.000Z", + "Updated At": "2022-04-14T15:11:21.064Z", + "Seconds to Label": 0.0, + "Agreement": -1.0, + "Benchmark Agreement": -1.0, + "Benchmark ID": None, + "Dataset Name": "cat", + "Reviews": [], + "View Label": "https://editor.labelbox.com?project=ckz38nsfd0lzq109bhq73est1&label=cl1z52xwh00050fhcmfgczqvn", + "Has Open Issues": 0.0, + "Skipped": False, + "media_type": "video", + "Data Split": None, + }, + ) + + +def video_serialized_bbox_label(): + return { + "uuid": "b24e672b-8f79-4d96-bf5e-b552ca0820d5", + "dataRow": {"id": "cklr9mr4m5iao0rb6cvxu4qbn"}, + "schemaId": "ckz38ofop0mci0z9i9w3aa9o4", + "name": "bbox toy", + "classifications": [], + "segments": [ + { + "keyframes": [ + { + "frame": 1, + "bbox": { + "top": 46.0, + "left": 70.0, + "height": 249.0, + "width": 384.0, + }, + "classifications": [], + }, + { + "frame": 5, + "bbox": { + "top": 32.0, + "left": 70.0, + "height": 316.0, + "width": 277.0, + }, + "classifications": [], + }, + ] + }, + { + "keyframes": [ + { + "frame": 9, + "bbox": { + "top": 132.0, + "left": 70.0, + "height": 216.0, + "width": 213.0, + }, + "classifications": [], + }, + { + "frame": 15, + "bbox": { + "top": 74.0, + "left": 70.0, + "height": 274.0, + "width": 288.0, + }, + "classifications": [], + }, + { + "frame": 21, + "bbox": { + "top": 31.0, + "left": 70.0, + "height": 317.0, + "width": 464.0, + }, + "classifications": [], + }, + { + "frame": 29, + "bbox": { + "top": 19.0, + "left": 70.0, + "height": 329.0, + "width": 536.0, + }, + "classifications": [], + }, + ] + }, + ], } @@ -603,9 +595,9 @@ def test_serialize_video_objects(): if key != "uuid": assert label[key] == manual_label[key] - assert len(label['segments']) == 2 - assert len(label['segments'][0]['keyframes']) == 2 - assert len(label['segments'][1]['keyframes']) == 4 + assert len(label["segments"]) == 2 + assert len(label["segments"][0]["keyframes"]) == 2 + assert len(label["segments"][1]["keyframes"]) == 4 # #converts back only the keyframes. 
should be the sum of all prev segments deserialized_labels = NDJsonConverter.deserialize([label]) @@ -618,7 +610,7 @@ def test_confidence_is_ignored(): serialized_labels = NDJsonConverter.serialize([label]) label = next(serialized_labels) label["confidence"] = 0.453 - label['segments'][0]["confidence"] = 0.453 + label["segments"][0]["confidence"] = 0.453 deserialized_labels = NDJsonConverter.deserialize([label]) label = next(deserialized_labels) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py b/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py index aaa84953a..84c017497 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py @@ -1,5 +1,10 @@ from labelbox.data.annotation_types.annotation import ClassificationAnnotation -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnswer, Radio, Text +from labelbox.data.annotation_types.classification.classification import ( + Checklist, + ClassificationAnswer, + Radio, + Text, +) from labelbox.data.annotation_types.data.text import TextData from labelbox.data.annotation_types.label import Label @@ -7,24 +12,27 @@ def test_serialization(): - label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation(name="free_text_annotation", - value=Text(confidence=0.5, - answer="text_answer")) - ]) + label = Label( + uid="ckj7z2q0b0000jx6x0q2q7q0d", + data=TextData( + uid="bkj7z2q0b0000jx6x0q2q7q0d", + text="This is a test", + ), + annotations=[ + ClassificationAnnotation( + name="free_text_annotation", + value=Text(confidence=0.5, answer="text_answer"), + ) + ], + ) serialized = NDJsonConverter.serialize([label]) res = next(serialized) - assert res['confidence'] == 0.5 - assert res['name'] == "free_text_annotation" - assert res['answer'] == "text_answer" - assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d" + assert res["confidence"] == 0.5 + assert res["name"] == "free_text_annotation" + assert res["answer"] == "text_answer" + assert res["dataRow"]["id"] == "bkj7z2q0b0000jx6x0q2q7q0d" deserialized = NDJsonConverter.deserialize([res]) res = next(deserialized) @@ -47,44 +55,53 @@ def test_nested_serialization(): annotations=[ ClassificationAnnotation( name="nested test", - value=Checklist(answer=[ - ClassificationAnswer( - name="first_answer", - confidence=0.9, - classifications=[ - ClassificationAnnotation( - name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer", - confidence=0.8, - classifications=[ - ClassificationAnnotation( - name="nested answer", - value=Text( - answer="nested answer", - confidence=0.7, - )) - ]))) - ]) - ]), + value=Checklist( + answer=[ + ClassificationAnswer( + name="first_answer", + confidence=0.9, + classifications=[ + ClassificationAnnotation( + name="sub_radio_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_radio_answer", + confidence=0.8, + classifications=[ + ClassificationAnnotation( + name="nested answer", + value=Text( + answer="nested answer", + confidence=0.7, + ), + ) + ], + ) + ), + ) + ], + ) + ] + ), ) - ]) + ], + ) serialized = NDJsonConverter.serialize([label]) res = next(serialized) - assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d" - answer = res['answer'][0] - assert answer['confidence'] == 0.9 - assert answer['name'] == "first_answer" - 
classification = answer['classifications'][0] - nested_classification_answer = classification['answer'] - assert nested_classification_answer['confidence'] == 0.8 - assert nested_classification_answer['name'] == "first_sub_radio_answer" - sub_classification = nested_classification_answer['classifications'][0] - assert sub_classification['name'] == "nested answer" - assert sub_classification['answer'] == "nested answer" - assert sub_classification['confidence'] == 0.7 + assert res["dataRow"]["id"] == "bkj7z2q0b0000jx6x0q2q7q0d" + answer = res["answer"][0] + assert answer["confidence"] == 0.9 + assert answer["name"] == "first_answer" + classification = answer["classifications"][0] + nested_classification_answer = classification["answer"] + assert nested_classification_answer["confidence"] == 0.8 + assert nested_classification_answer["name"] == "first_sub_radio_answer" + sub_classification = nested_classification_answer["classifications"][0] + assert sub_classification["name"] == "nested answer" + assert sub_classification["answer"] == "nested answer" + assert sub_classification["confidence"] == 0.7 deserialized = NDJsonConverter.deserialize([res]) res = next(deserialized) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py b/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py index 6de2dcc51..2b3fa7f8c 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py @@ -20,15 +20,18 @@ def round_dict(data): return data -@pytest.mark.parametrize('filename', [ - 'tests/data/assets/ndjson/classification_import_global_key.json', - 'tests/data/assets/ndjson/metric_import_global_key.json', - 'tests/data/assets/ndjson/polyline_import_global_key.json', - 'tests/data/assets/ndjson/text_entity_import_global_key.json', - 'tests/data/assets/ndjson/conversation_entity_import_global_key.json', -]) +@pytest.mark.parametrize( + "filename", + [ + "tests/data/assets/ndjson/classification_import_global_key.json", + "tests/data/assets/ndjson/metric_import_global_key.json", + "tests/data/assets/ndjson/polyline_import_global_key.json", + "tests/data/assets/ndjson/text_entity_import_global_key.json", + "tests/data/assets/ndjson/conversation_entity_import_global_key.json", + ], +) def test_many_types(filename: str): - with open(filename, 'r') as f: + with open(filename, "r") as f: data = json.load(f) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) @@ -37,19 +40,20 @@ def test_many_types(filename: str): def test_image(): - with open('tests/data/assets/ndjson/image_import_global_key.json', - 'r') as f: + with open( + "tests/data/assets/ndjson/image_import_global_key.json", "r" + ) as f: data = json.load(f) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) for r in res: - r.pop('classifications', None) + r.pop("classifications", None) assert [round_dict(x) for x in res] == [round_dict(x) for x in data] f.close() def test_pdf(): - with open('tests/data/assets/ndjson/pdf_import_global_key.json', 'r') as f: + with open("tests/data/assets/ndjson/pdf_import_global_key.json", "r") as f: data = json.load(f) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) @@ -58,8 +62,9 @@ def test_pdf(): def test_video(): - with open('tests/data/assets/ndjson/video_import_global_key.json', - 'r') as f: + with open( + "tests/data/assets/ndjson/video_import_global_key.json", "r" + ) as f: data = 
json.load(f) res = list(NDJsonConverter.deserialize(data)) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_image.py b/libs/labelbox/tests/data/serialization/ndjson/test_image.py index e36ce6f50..1729e1f46 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_image.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_image.py @@ -3,7 +3,13 @@ import cv2 from labelbox.data.serialization.ndjson.converter import NDJsonConverter -from labelbox.data.annotation_types import Mask, Label, ObjectAnnotation, ImageData, MaskData +from labelbox.data.annotation_types import ( + Mask, + Label, + ObjectAnnotation, + ImageData, + MaskData, +) def round_dict(data): @@ -20,61 +26,56 @@ def round_dict(data): def test_image(): - with open('tests/data/assets/ndjson/image_import.json', 'r') as file: + with open("tests/data/assets/ndjson/image_import.json", "r") as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) for r in res: - r.pop('classifications', None) + r.pop("classifications", None) assert [round_dict(x) for x in res] == [round_dict(x) for x in data] def test_image_with_name_only(): - with open('tests/data/assets/ndjson/image_import_name_only.json', - 'r') as file: + with open( + "tests/data/assets/ndjson/image_import_name_only.json", "r" + ) as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) for r in res: - r.pop('classifications', None) + r.pop("classifications", None) assert [round_dict(x) for x in res] == [round_dict(x) for x in data] def test_mask(): - data = [{ - "uuid": "b862c586-8614-483c-b5e6-82810f70cac0", - "schemaId": "ckrazcueb16og0z6609jj7y3y", - "dataRow": { - "id": "ckrazctum0z8a0ybc0b0o0g0v" + data = [ + { + "uuid": "b862c586-8614-483c-b5e6-82810f70cac0", + "schemaId": "ckrazcueb16og0z6609jj7y3y", + "dataRow": {"id": "ckrazctum0z8a0ybc0b0o0g0v"}, + "mask": { + "png": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAAAAACoWZBhAAAAMklEQVR4nD3MuQ3AQADDMOqQ/Vd2ijytaSiZLAcYuyLEYYYl9cvrlGftTHvsYl+u/3EDv0QLI8Z7FlwAAAAASUVORK5CYII=" + }, + "confidence": 0.8, + "customMetrics": [{"name": "customMetric1", "value": 0.4}], }, - "mask": { - "png": - "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAAAAACoWZBhAAAAMklEQVR4nD3MuQ3AQADDMOqQ/Vd2ijytaSiZLAcYuyLEYYYl9cvrlGftTHvsYl+u/3EDv0QLI8Z7FlwAAAAASUVORK5CYII=" - }, - "confidence": 0.8, - "customMetrics": [{ - "name": "customMetric1", - "value": 0.4 - }], - }, { - "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6", - "schemaId": "ckrazcuec16ok0z66f956apb7", - "dataRow": { - "id": "ckrazctum0z8a0ybc0b0o0g0v" + { + "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6", + "schemaId": "ckrazcuec16ok0z66f956apb7", + "dataRow": {"id": "ckrazctum0z8a0ybc0b0o0g0v"}, + "mask": { + "instanceURI": "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY", + "colorRGB": [255, 0, 0], + }, }, - "mask": { - "instanceURI": - "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY", - "colorRGB": [255, 0, 0] - } - }] + ] res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) for r in res: - r.pop('classifications', None) + r.pop("classifications", None) assert [round_dict(x) for x in res] == [round_dict(x) for x in data] @@ -83,22 +84,24 @@ def 
test_mask_from_arr(): mask_arr = np.round(np.zeros((32, 32))).astype(np.uint8) mask_arr = cv2.rectangle(mask_arr, (5, 5), (10, 10), (1, 1), -1) - label = Label(annotations=[ - ObjectAnnotation(feature_schema_id="1" * 25, - value=Mask(mask=MaskData.from_2D_arr(arr=mask_arr), - color=(1, 1, 1))) - ], - data=ImageData(uid="0" * 25)) + label = Label( + annotations=[ + ObjectAnnotation( + feature_schema_id="1" * 25, + value=Mask( + mask=MaskData.from_2D_arr(arr=mask_arr), color=(1, 1, 1) + ), + ) + ], + data=ImageData(uid="0" * 25), + ) res = next(NDJsonConverter.serialize([label])) res.pop("uuid") assert res == { "classifications": [], "schemaId": "1" * 25, - "dataRow": { - "id": "0" * 25 - }, + "dataRow": {"id": "0" * 25}, "mask": { - "png": - "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAAAAABWESUoAAAAHklEQVR4nGNgGAKAEYn8j00BEyETBoOCUTAKhhwAAJW+AQwvpePVAAAAAElFTkSuQmCC" - } + "png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAAAAABWESUoAAAAHklEQVR4nGNgGAKAEYn8j00BEyETBoOCUTAKhhwAAJW+AQwvpePVAAAAAElFTkSuQmCC" + }, } diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_metric.py b/libs/labelbox/tests/data/serialization/ndjson/test_metric.py index 6508b73af..45c5c67bf 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_metric.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_metric.py @@ -4,7 +4,7 @@ def test_metric(): - with open('tests/data/assets/ndjson/metric_import.json', 'r') as file: + with open("tests/data/assets/ndjson/metric_import.json", "r") as file: data = json.load(file) label_list = list(NDJsonConverter.deserialize(data)) @@ -13,22 +13,26 @@ def test_metric(): def test_custom_scalar_metric(): - with open('tests/data/assets/ndjson/custom_scalar_import.json', - 'r') as file: + with open( + "tests/data/assets/ndjson/custom_scalar_import.json", "r" + ) as file: data = json.load(file) label_list = list(NDJsonConverter.deserialize(data)) reserialized = list(NDJsonConverter.serialize(label_list)) - assert json.dumps(reserialized, - sort_keys=True) == json.dumps(data, sort_keys=True) + assert json.dumps(reserialized, sort_keys=True) == json.dumps( + data, sort_keys=True + ) def test_custom_confusion_matrix_metric(): - with open('tests/data/assets/ndjson/custom_confusion_matrix_import.json', - 'r') as file: + with open( + "tests/data/assets/ndjson/custom_confusion_matrix_import.json", "r" + ) as file: data = json.load(file) label_list = list(NDJsonConverter.deserialize(data)) reserialized = list(NDJsonConverter.serialize(label_list)) - assert json.dumps(reserialized, - sort_keys=True) == json.dumps(data, sort_keys=True) + assert json.dumps(reserialized, sort_keys=True) == json.dumps( + data, sort_keys=True + ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py b/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py index bc093b79b..69594ff73 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py @@ -6,7 +6,7 @@ def test_message_task_annotation_serialization(): - with open('tests/data/assets/ndjson/mmc_import.json', 'r') as file: + with open("tests/data/assets/ndjson/mmc_import.json", "r") as file: data = json.load(file) deserialized = list(NDJsonConverter.deserialize(data)) @@ -16,14 +16,17 @@ def test_message_task_annotation_serialization(): def test_mesage_ranking_task_wrong_order_serialization(): - with open('tests/data/assets/ndjson/mmc_import.json', 'r') as file: + with open("tests/data/assets/ndjson/mmc_import.json", "r") as file: data = json.load(file) 
some_ranking_task = next( - task for task in data - if task["messageEvaluationTask"]["format"] == "message-ranking") + task + for task in data + if task["messageEvaluationTask"]["format"] == "message-ranking" + ) some_ranking_task["messageEvaluationTask"]["data"]["rankedMessages"][0][ - "order"] = 3 + "order" + ] = 3 with pytest.raises(ValueError): list(NDJsonConverter.deserialize([some_ranking_task])) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py b/libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py index 1f51c307a..790bd87b3 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py @@ -5,13 +5,15 @@ def test_bad_annotation_input(): - data = [{ - "test": 3 - }] + data = [{"test": 3}] with pytest.raises(ValueError): NDLabel(**{"annotations": data}) + def test_correct_annotation_input(): - with open('tests/data/assets/ndjson/pdf_import_name_only.json', 'r') as f: + with open("tests/data/assets/ndjson/pdf_import_name_only.json", "r") as f: data = json.load(f) - assert isinstance(NDLabel(**{"annotations": [data[0]]}).annotations[0], NDDocumentRectangle) + assert isinstance( + NDLabel(**{"annotations": [data[0]]}).annotations[0], + NDDocumentRectangle, + ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_nested.py b/libs/labelbox/tests/data/serialization/ndjson/test_nested.py index 69fddf1ff..e0f0df0e6 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_nested.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_nested.py @@ -4,7 +4,7 @@ def test_nested(): - with open('tests/data/assets/ndjson/nested_import.json', 'r') as file: + with open("tests/data/assets/ndjson/nested_import.json", "r") as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) @@ -12,8 +12,9 @@ def test_nested(): def test_nested_name_only(): - with open('tests/data/assets/ndjson/nested_import_name_only.json', - 'r') as file: + with open( + "tests/data/assets/ndjson/nested_import_name_only.json", "r" + ) as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py b/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py index 933c378df..97d48a14e 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py @@ -3,12 +3,15 @@ from labelbox.data.serialization.ndjson.converter import NDJsonConverter -@pytest.mark.parametrize("filename", [ - "tests/data/assets/ndjson/polyline_without_confidence_import.json", - "tests/data/assets/ndjson/polyline_import.json" -]) +@pytest.mark.parametrize( + "filename", + [ + "tests/data/assets/ndjson/polyline_without_confidence_import.json", + "tests/data/assets/ndjson/polyline_import.json", + ], +) def test_polyline_import(filename: str): - with open(filename, 'r') as file: + with open(filename, "r") as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_radio.py b/libs/labelbox/tests/data/serialization/ndjson/test_radio.py index 97cb073e0..bd80f9267 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_radio.py +++ 
b/libs/labelbox/tests/data/serialization/ndjson/test_radio.py @@ -1,6 +1,8 @@ import json from labelbox.data.annotation_types.annotation import ClassificationAnnotation -from labelbox.data.annotation_types.classification.classification import ClassificationAnswer +from labelbox.data.annotation_types.classification.classification import ( + ClassificationAnswer, +) from labelbox.data.annotation_types.classification.classification import Radio from labelbox.data.annotation_types.data.text import TextData from labelbox.data.annotation_types.label import Label @@ -19,17 +21,18 @@ def test_serialization_with_radio_min(): ClassificationAnnotation( name="radio_question_geo", value=Radio( - answer=ClassificationAnswer(name="first_radio_answer",))) - ]) + answer=ClassificationAnswer( + name="first_radio_answer", + ) + ), + ) + ], + ) expected = { - 'name': 'radio_question_geo', - 'answer': { - 'name': 'first_radio_answer' - }, - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - } + "name": "radio_question_geo", + "answer": {"name": "first_radio_answer"}, + "dataRow": {"id": "bkj7z2q0b0000jx6x0q2q7q0d"}, } serialized = NDJsonConverter.serialize([label]) res = next(serialized) @@ -47,43 +50,51 @@ def test_serialization_with_radio_min(): def test_serialization_with_radio_classification(): - label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation( - name="radio_question_geo", - confidence=0.5, - value=Radio(answer=ClassificationAnswer( - confidence=0.6, - name="first_radio_answer", - classifications=[ - ClassificationAnnotation( - name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer"))) - ]))) - ]) + label = Label( + uid="ckj7z2q0b0000jx6x0q2q7q0d", + data=TextData( + uid="bkj7z2q0b0000jx6x0q2q7q0d", + text="This is a test", + ), + annotations=[ + ClassificationAnnotation( + name="radio_question_geo", + confidence=0.5, + value=Radio( + answer=ClassificationAnswer( + confidence=0.6, + name="first_radio_answer", + classifications=[ + ClassificationAnnotation( + name="sub_radio_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_radio_answer" + ) + ), + ) + ], + ) + ), + ) + ], + ) expected = { - 'confidence': 0.5, - 'name': 'radio_question_geo', - 'answer': { - 'confidence': - 0.6, - 'name': - 'first_radio_answer', - 'classifications': [{ - 'name': 'sub_radio_question', - 'answer': { - 'name': 'first_sub_radio_answer', + "confidence": 0.5, + "name": "radio_question_geo", + "answer": { + "confidence": 0.6, + "name": "first_radio_answer", + "classifications": [ + { + "name": "sub_radio_question", + "answer": { + "name": "first_sub_radio_answer", + }, } - }] + ], }, - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - } + "dataRow": {"id": "bkj7z2q0b0000jx6x0q2q7q0d"}, } serialized = NDJsonConverter.serialize([label]) @@ -94,5 +105,6 @@ def test_serialization_with_radio_classification(): deserialized = NDJsonConverter.deserialize([res]) res = next(deserialized) res.annotations[0].extra.pop("uuid") - assert res.annotations[0].model_dump(exclude_none=True) == label.annotations[0].model_dump(exclude_none=True) - + assert res.annotations[0].model_dump( + exclude_none=True + ) == label.annotations[0].model_dump(exclude_none=True) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py b/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py index c07dcc66d..66630dbb5 100644 --- 
a/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py @@ -6,7 +6,7 @@ def test_rectangle(): - with open('tests/data/assets/ndjson/rectangle_import.json', 'r') as file: + with open("tests/data/assets/ndjson/rectangle_import.json", "r") as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) @@ -14,7 +14,7 @@ def test_rectangle(): def test_rectangle_inverted_start_end_points(): - with open('tests/data/assets/ndjson/rectangle_import.json', 'r') as file: + with open("tests/data/assets/ndjson/rectangle_import.json", "r") as file: data = json.load(file) bbox = lb_types.ObjectAnnotation( @@ -23,10 +23,10 @@ def test_rectangle_inverted_start_end_points(): start=lb_types.Point(x=81, y=69), end=lb_types.Point(x=38, y=28), ), - extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}) + extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}, + ) - label = lb_types.Label(data={"uid":DATAROW_ID}, - annotations=[bbox]) + label = lb_types.Label(data={"uid": DATAROW_ID}, annotations=[bbox]) res = list(NDJsonConverter.serialize([label])) assert res == data @@ -40,18 +40,20 @@ def test_rectangle_inverted_start_end_points(): extra={ "uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72", "page": None, - "unit": None - }) + "unit": None, + }, + ) - label = lb_types.Label(data={"uid":DATAROW_ID}, - annotations=[expected_bbox]) + label = lb_types.Label( + data={"uid": DATAROW_ID}, annotations=[expected_bbox] + ) res = list(NDJsonConverter.deserialize(res)) assert res == [label] def test_rectangle_mixed_start_end_points(): - with open('tests/data/assets/ndjson/rectangle_import.json', 'r') as file: + with open("tests/data/assets/ndjson/rectangle_import.json", "r") as file: data = json.load(file) bbox = lb_types.ObjectAnnotation( @@ -60,10 +62,10 @@ def test_rectangle_mixed_start_end_points(): start=lb_types.Point(x=81, y=28), end=lb_types.Point(x=38, y=69), ), - extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}) + extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}, + ) - label = lb_types.Label(data={"uid":DATAROW_ID}, - annotations=[bbox]) + label = lb_types.Label(data={"uid": DATAROW_ID}, annotations=[bbox]) res = list(NDJsonConverter.serialize([label])) assert res == data @@ -77,11 +79,11 @@ def test_rectangle_mixed_start_end_points(): extra={ "uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72", "page": None, - "unit": None - }) + "unit": None, + }, + ) - label = lb_types.Label(data={"uid":DATAROW_ID}, - annotations=[bbox]) + label = lb_types.Label(data={"uid": DATAROW_ID}, annotations=[bbox]) res = list(NDJsonConverter.deserialize(res)) assert res == [label] @@ -94,13 +96,13 @@ def test_benchmark_reference_label_flag_enabled(): start=lb_types.Point(x=81, y=28), end=lb_types.Point(x=38, y=69), ), - extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"} + extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}, ) label = lb_types.Label( - data={"uid":DATAROW_ID}, + data={"uid": DATAROW_ID}, annotations=[bbox], - is_benchmark_reference=True + is_benchmark_reference=True, ) res = list(NDJsonConverter.serialize([label])) @@ -114,13 +116,13 @@ def test_benchmark_reference_label_flag_disabled(): start=lb_types.Point(x=81, y=28), end=lb_types.Point(x=38, y=69), ), - extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"} + extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}, ) label = lb_types.Label( - data={"uid":DATAROW_ID}, + data={"uid": DATAROW_ID}, annotations=[bbox], - 
is_benchmark_reference=False + is_benchmark_reference=False, ) res = list(NDJsonConverter.serialize([label])) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py b/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py index 9ede41d2c..f33719035 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py @@ -30,10 +30,14 @@ def test_relationship(): ] assert res_relationship_second_annotation - assert res_relationship_second_annotation["relationship"][ - "source"] != res_relationship_annotation["relationship"]["source"] - assert res_relationship_second_annotation["relationship"][ - "target"] != res_relationship_annotation["relationship"]["target"] + assert ( + res_relationship_second_annotation["relationship"]["source"] + != res_relationship_annotation["relationship"]["source"] + ) + assert ( + res_relationship_second_annotation["relationship"]["target"] + != res_relationship_annotation["relationship"]["target"] + ) assert res_relationship_second_annotation["relationship"]["source"] in [ annot["uuid"] for annot in res_source_and_target ] diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_text.py b/libs/labelbox/tests/data/serialization/ndjson/test_text.py index 534068e14..d5e81c51a 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_text.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_text.py @@ -1,5 +1,9 @@ from labelbox.data.annotation_types.annotation import ClassificationAnnotation -from labelbox.data.annotation_types.classification.classification import ClassificationAnswer, Radio, Text +from labelbox.data.annotation_types.classification.classification import ( + ClassificationAnswer, + Radio, + Text, +) from labelbox.data.annotation_types.data.text import TextData from labelbox.data.annotation_types.label import Label @@ -7,24 +11,29 @@ def test_serialization(): - label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation( - name="radio_question_geo", - confidence=0.5, - value=Text(answer="first_radio_answer")) - ]) + label = Label( + uid="ckj7z2q0b0000jx6x0q2q7q0d", + data=TextData( + uid="bkj7z2q0b0000jx6x0q2q7q0d", + text="This is a test", + ), + annotations=[ + ClassificationAnnotation( + name="radio_question_geo", + confidence=0.5, + value=Text(answer="first_radio_answer"), + ) + ], + ) serialized = NDJsonConverter.serialize([label]) res = next(serialized) - assert 'confidence' not in res # because confidence needs to be set on the annotation itself - assert res['name'] == "radio_question_geo" - assert res['answer'] == "first_radio_answer" - assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d" + assert ( + "confidence" not in res + ) # because confidence needs to be set on the annotation itself + assert res["name"] == "radio_question_geo" + assert res["answer"] == "first_radio_answer" + assert res["dataRow"]["id"] == "bkj7z2q0b0000jx6x0q2q7q0d" deserialized = NDJsonConverter.deserialize([res]) res = next(deserialized) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py b/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py index f62d87ebc..3e856f001 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py @@ -5,12 +5,15 @@ from 
labelbox.data.serialization.ndjson.converter import NDJsonConverter -@pytest.mark.parametrize("filename", [ - "tests/data/assets/ndjson/text_entity_import.json", - "tests/data/assets/ndjson/text_entity_without_confidence_import.json" -]) +@pytest.mark.parametrize( + "filename", + [ + "tests/data/assets/ndjson/text_entity_import.json", + "tests/data/assets/ndjson/text_entity_without_confidence_import.json", + ], +) def test_text_entity_import(filename: str): - with open(filename, 'r') as file: + with open(filename, "r") as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_video.py b/libs/labelbox/tests/data/serialization/ndjson/test_video.py index 4b90a8060..c7a6535c4 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_video.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_video.py @@ -1,6 +1,11 @@ import json from labelbox.client import Client -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnnotation, ClassificationAnswer, Radio +from labelbox.data.annotation_types.classification.classification import ( + Checklist, + ClassificationAnnotation, + ClassificationAnswer, + Radio, +) from labelbox.data.annotation_types.data.video import VideoData from labelbox.data.annotation_types.geometry.line import Line from labelbox.data.annotation_types.geometry.point import Point @@ -16,29 +21,31 @@ def test_video(): - with open('tests/data/assets/ndjson/video_import.json', 'r') as file: + with open("tests/data/assets/ndjson/video_import.json", "r") as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) - - data = sorted(data, key=itemgetter('uuid')) - res = sorted(res, key=itemgetter('uuid')) + + data = sorted(data, key=itemgetter("uuid")) + res = sorted(res, key=itemgetter("uuid")) pairs = zip(data, res) for data, res in pairs: assert data == res + def test_video_name_only(): - with open('tests/data/assets/ndjson/video_import_name_only.json', - 'r') as file: + with open( + "tests/data/assets/ndjson/video_import_name_only.json", "r" + ) as file: data = json.load(file) res = list(NDJsonConverter.deserialize(data)) res = list(NDJsonConverter.serialize(res)) - - data = sorted(data, key=itemgetter('uuid')) - res = sorted(res, key=itemgetter('uuid')) + + data = sorted(data, key=itemgetter("uuid")) + res = sorted(res, key=itemgetter("uuid")) pairs = zip(data, res) for data, res in pairs: @@ -47,54 +54,60 @@ def test_video_name_only(): def test_video_classification_global_subclassifications(): label = Label( - data=VideoData(global_key="sample-video-4.mp4",), + data=VideoData( + global_key="sample-video-4.mp4", + ), annotations=[ ClassificationAnnotation( - name='radio_question_nested', - value=Radio(answer=ClassificationAnswer( - name='first_radio_question')), + name="radio_question_nested", + value=Radio( + answer=ClassificationAnswer(name="first_radio_question") + ), ), ClassificationAnnotation( - name='nested_checklist_question', + name="nested_checklist_question", value=Checklist( - name='checklist', + name="checklist", answer=[ ClassificationAnswer( - name='first_checklist_answer', + name="first_checklist_answer", classifications=[ ClassificationAnnotation( - name='sub_checklist_question', - value=Radio(answer=ClassificationAnswer( - name='first_sub_checklist_answer'))) - ]) - ])) - ]) + name="sub_checklist_question", + value=Radio( + answer=ClassificationAnswer( + 
name="first_sub_checklist_answer" + ) + ), + ) + ], + ) + ], + ), + ), + ], + ) expected_first_annotation = { - 'name': 'radio_question_nested', - 'answer': { - 'name': 'first_radio_question' - }, - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - } + "name": "radio_question_nested", + "answer": {"name": "first_radio_question"}, + "dataRow": {"globalKey": "sample-video-4.mp4"}, } expected_second_annotation = nested_checklist_annotation_ndjson = { "name": "nested_checklist_question", - "answer": [{ - "name": - "first_checklist_answer", - "classifications": [{ - "name": "sub_checklist_question", - "answer": { - "name": "first_sub_checklist_answer" - } - }] - }], - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - } + "answer": [ + { + "name": "first_checklist_answer", + "classifications": [ + { + "name": "sub_checklist_question", + "answer": {"name": "first_sub_checklist_answer"}, + } + ], + } + ], + "dataRow": {"globalKey": "sample-video-4.mp4"}, } serialized = NDJsonConverter.serialize([label]) @@ -123,18 +136,27 @@ def test_video_classification_nesting_bbox(): ), classifications=[ ClassificationAnnotation( - name='radio_question_nested', - value=Radio(answer=ClassificationAnswer( - name='first_radio_question', - classifications=[ - ClassificationAnnotation(name='sub_question_radio', - value=Checklist(answer=[ - ClassificationAnswer( - name='sub_answer') - ])) - ])), + name="radio_question_nested", + value=Radio( + answer=ClassificationAnswer( + name="first_radio_question", + classifications=[ + ClassificationAnnotation( + name="sub_question_radio", + value=Checklist( + answer=[ + ClassificationAnswer( + name="sub_answer" + ) + ] + ), + ) + ], + ) + ), ) - ]), + ], + ), VideoObjectAnnotation( name="bbox_video", keyframe=True, @@ -146,18 +168,27 @@ def test_video_classification_nesting_bbox(): ), classifications=[ ClassificationAnnotation( - name='nested_checklist_question', - value=Checklist(answer=[ - ClassificationAnswer( - name='first_checklist_answer', - classifications=[ - ClassificationAnnotation( - name='sub_checklist_question', - value=Radio(answer=ClassificationAnswer( - name='first_sub_checklist_answer'))) - ]) - ])) - ]), + name="nested_checklist_question", + value=Checklist( + answer=[ + ClassificationAnswer( + name="first_checklist_answer", + classifications=[ + ClassificationAnnotation( + name="sub_checklist_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_checklist_answer" + ) + ), + ) + ], + ) + ] + ), + ) + ], + ), VideoObjectAnnotation( name="bbox_video", keyframe=True, @@ -166,76 +197,91 @@ def test_video_classification_nesting_bbox(): value=Rectangle( start=Point(x=146.0, y=98.0), # Top left end=Point(x=382.0, y=341.0), # Bottom right - )) + ), + ), ] - expected = [{ - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - }, - 'name': - 'bbox_video', - 'classifications': [], - 'segments': [{ - 'keyframes': [{ - 'frame': - 13, - 'bbox': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - }, - 'classifications': [{ - 'name': 'radio_question_nested', - 'answer': { - 'name': - 'first_radio_question', - 'classifications': [{ - 'name': 'sub_question_radio', - 'answer': [{ - 'name': 'sub_answer' - }] - }] - } - }] - }, { - 'frame': - 15, - 'bbox': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - }, - 'classifications': [{ - 'name': - 'nested_checklist_question', - 'answer': [{ - 'name': - 'first_checklist_answer', - 'classifications': [{ - 'name': 'sub_checklist_question', - 'answer': { - 'name': 
'first_sub_checklist_answer' - } - }] - }] - }] - }, { - 'frame': 19, - 'bbox': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - }, - 'classifications': [] - }] - }] - }] - - label = Label(data=VideoData(global_key="sample-video-4.mp4",), - annotations=bbox_annotation) + expected = [ + { + "dataRow": {"globalKey": "sample-video-4.mp4"}, + "name": "bbox_video", + "classifications": [], + "segments": [ + { + "keyframes": [ + { + "frame": 13, + "bbox": { + "top": 98.0, + "left": 146.0, + "height": 243.0, + "width": 236.0, + }, + "classifications": [ + { + "name": "radio_question_nested", + "answer": { + "name": "first_radio_question", + "classifications": [ + { + "name": "sub_question_radio", + "answer": [ + {"name": "sub_answer"} + ], + } + ], + }, + } + ], + }, + { + "frame": 15, + "bbox": { + "top": 98.0, + "left": 146.0, + "height": 243.0, + "width": 236.0, + }, + "classifications": [ + { + "name": "nested_checklist_question", + "answer": [ + { + "name": "first_checklist_answer", + "classifications": [ + { + "name": "sub_checklist_question", + "answer": { + "name": "first_sub_checklist_answer" + }, + } + ], + } + ], + } + ], + }, + { + "frame": 19, + "bbox": { + "top": 98.0, + "left": 146.0, + "height": 243.0, + "width": 236.0, + }, + "classifications": [], + }, + ] + } + ], + } + ] + + label = Label( + data=VideoData( + global_key="sample-video-4.mp4", + ), + annotations=bbox_annotation, + ) serialized = NDJsonConverter.serialize([label]) res = [x for x in serialized] @@ -260,18 +306,27 @@ def test_video_classification_point(): value=Point(x=46.0, y=8.0), classifications=[ ClassificationAnnotation( - name='radio_question_nested', - value=Radio(answer=ClassificationAnswer( - name='first_radio_question', - classifications=[ - ClassificationAnnotation(name='sub_question_radio', - value=Checklist(answer=[ - ClassificationAnswer( - name='sub_answer') - ])) - ])), + name="radio_question_nested", + value=Radio( + answer=ClassificationAnswer( + name="first_radio_question", + classifications=[ + ClassificationAnnotation( + name="sub_question_radio", + value=Checklist( + answer=[ + ClassificationAnswer( + name="sub_answer" + ) + ] + ), + ) + ], + ) + ), ) - ]), + ], + ), VideoObjectAnnotation( name="bbox_video", keyframe=True, @@ -280,88 +335,111 @@ def test_video_classification_point(): value=Point(x=56.0, y=18.0), classifications=[ ClassificationAnnotation( - name='nested_checklist_question', - value=Checklist(answer=[ - ClassificationAnswer( - name='first_checklist_answer', - classifications=[ - ClassificationAnnotation( - name='sub_checklist_question', - value=Radio(answer=ClassificationAnswer( - name='first_sub_checklist_answer'))) - ]) - ])) - ]), + name="nested_checklist_question", + value=Checklist( + answer=[ + ClassificationAnswer( + name="first_checklist_answer", + classifications=[ + ClassificationAnnotation( + name="sub_checklist_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_checklist_answer" + ) + ), + ) + ], + ) + ] + ), + ) + ], + ), VideoObjectAnnotation( name="bbox_video", keyframe=True, frame=19, segment_index=0, value=Point(x=66.0, y=28.0), - ) + ), + ] + expected = [ + { + "dataRow": {"globalKey": "sample-video-4.mp4"}, + "name": "bbox_video", + "classifications": [], + "segments": [ + { + "keyframes": [ + { + "frame": 13, + "point": { + "x": 46.0, + "y": 8.0, + }, + "classifications": [ + { + "name": "radio_question_nested", + "answer": { + "name": "first_radio_question", + "classifications": [ + { + "name": 
"sub_question_radio", + "answer": [ + {"name": "sub_answer"} + ], + } + ], + }, + } + ], + }, + { + "frame": 15, + "point": { + "x": 56.0, + "y": 18.0, + }, + "classifications": [ + { + "name": "nested_checklist_question", + "answer": [ + { + "name": "first_checklist_answer", + "classifications": [ + { + "name": "sub_checklist_question", + "answer": { + "name": "first_sub_checklist_answer" + }, + } + ], + } + ], + } + ], + }, + { + "frame": 19, + "point": { + "x": 66.0, + "y": 28.0, + }, + "classifications": [], + }, + ] + } + ], + } ] - expected = [{ - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - }, - 'name': - 'bbox_video', - 'classifications': [], - 'segments': [{ - 'keyframes': [{ - 'frame': - 13, - 'point': { - 'x': 46.0, - 'y': 8.0, - }, - 'classifications': [{ - 'name': 'radio_question_nested', - 'answer': { - 'name': - 'first_radio_question', - 'classifications': [{ - 'name': 'sub_question_radio', - 'answer': [{ - 'name': 'sub_answer' - }] - }] - } - }] - }, { - 'frame': - 15, - 'point': { - 'x': 56.0, - 'y': 18.0, - }, - 'classifications': [{ - 'name': - 'nested_checklist_question', - 'answer': [{ - 'name': - 'first_checklist_answer', - 'classifications': [{ - 'name': 'sub_checklist_question', - 'answer': { - 'name': 'first_sub_checklist_answer' - } - }] - }] - }] - }, { - 'frame': 19, - 'point': { - 'x': 66.0, - 'y': 28.0, - }, - 'classifications': [] - }] - }] - }] - - label = Label(data=VideoData(global_key="sample-video-4.mp4",), - annotations=bbox_annotation) + + label = Label( + data=VideoData( + global_key="sample-video-4.mp4", + ), + annotations=bbox_annotation, + ) serialized = NDJsonConverter.serialize([label]) res = [x for x in serialized] @@ -382,123 +460,161 @@ def test_video_classification_frameline(): keyframe=True, frame=13, segment_index=0, - value=Line( - points=[Point(x=8, y=10), Point(x=10, y=9)]), + value=Line(points=[Point(x=8, y=10), Point(x=10, y=9)]), classifications=[ ClassificationAnnotation( - name='radio_question_nested', - value=Radio(answer=ClassificationAnswer( - name='first_radio_question', - classifications=[ - ClassificationAnnotation(name='sub_question_radio', - value=Checklist(answer=[ - ClassificationAnswer( - name='sub_answer') - ])) - ])), + name="radio_question_nested", + value=Radio( + answer=ClassificationAnswer( + name="first_radio_question", + classifications=[ + ClassificationAnnotation( + name="sub_question_radio", + value=Checklist( + answer=[ + ClassificationAnswer( + name="sub_answer" + ) + ] + ), + ) + ], + ) + ), ) - ]), + ], + ), VideoObjectAnnotation( name="bbox_video", keyframe=True, frame=15, segment_index=0, - value=Line( - points=[Point(x=18, y=20), Point(x=20, y=19)]), + value=Line(points=[Point(x=18, y=20), Point(x=20, y=19)]), classifications=[ ClassificationAnnotation( - name='nested_checklist_question', - value=Checklist(answer=[ - ClassificationAnswer( - name='first_checklist_answer', - classifications=[ - ClassificationAnnotation( - name='sub_checklist_question', - value=Radio(answer=ClassificationAnswer( - name='first_sub_checklist_answer'))) - ]) - ])) - ]), + name="nested_checklist_question", + value=Checklist( + answer=[ + ClassificationAnswer( + name="first_checklist_answer", + classifications=[ + ClassificationAnnotation( + name="sub_checklist_question", + value=Radio( + answer=ClassificationAnswer( + name="first_sub_checklist_answer" + ) + ), + ) + ], + ) + ] + ), + ) + ], + ), VideoObjectAnnotation( name="bbox_video", keyframe=True, frame=19, segment_index=0, - value=Line( - points=[Point(x=28, y=30), 
Point(x=30, y=29)]), - ) + value=Line(points=[Point(x=28, y=30), Point(x=30, y=29)]), + ), ] - expected = [{ - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - }, - 'name': - 'bbox_video', - 'classifications': [], - 'segments': [{ - 'keyframes': [{ - 'frame': - 13, - 'line': [{ - 'x': 8.0, - 'y': 10.0, - }, { - 'x': 10.0, - 'y': 9.0, - }], - 'classifications': [{ - 'name': 'radio_question_nested', - 'answer': { - 'name': - 'first_radio_question', - 'classifications': [{ - 'name': 'sub_question_radio', - 'answer': [{ - 'name': 'sub_answer' - }] - }] - } - }] - }, { - 'frame': - 15, - 'line': [{ - 'x': 18.0, - 'y': 20.0, - }, { - 'x': 20.0, - 'y': 19.0, - }], - 'classifications': [{ - 'name': - 'nested_checklist_question', - 'answer': [{ - 'name': - 'first_checklist_answer', - 'classifications': [{ - 'name': 'sub_checklist_question', - 'answer': { - 'name': 'first_sub_checklist_answer' - } - }] - }] - }] - }, { - 'frame': 19, - 'line': [{ - 'x': 28.0, - 'y': 30.0, - }, { - 'x': 30.0, - 'y': 29.0, - }], - 'classifications': [] - }] - }] - }] - - label = Label(data=VideoData(global_key="sample-video-4.mp4",), - annotations=bbox_annotation) + expected = [ + { + "dataRow": {"globalKey": "sample-video-4.mp4"}, + "name": "bbox_video", + "classifications": [], + "segments": [ + { + "keyframes": [ + { + "frame": 13, + "line": [ + { + "x": 8.0, + "y": 10.0, + }, + { + "x": 10.0, + "y": 9.0, + }, + ], + "classifications": [ + { + "name": "radio_question_nested", + "answer": { + "name": "first_radio_question", + "classifications": [ + { + "name": "sub_question_radio", + "answer": [ + {"name": "sub_answer"} + ], + } + ], + }, + } + ], + }, + { + "frame": 15, + "line": [ + { + "x": 18.0, + "y": 20.0, + }, + { + "x": 20.0, + "y": 19.0, + }, + ], + "classifications": [ + { + "name": "nested_checklist_question", + "answer": [ + { + "name": "first_checklist_answer", + "classifications": [ + { + "name": "sub_checklist_question", + "answer": { + "name": "first_sub_checklist_answer" + }, + } + ], + } + ], + } + ], + }, + { + "frame": 19, + "line": [ + { + "x": 28.0, + "y": 30.0, + }, + { + "x": 30.0, + "y": 29.0, + }, + ], + "classifications": [], + }, + ] + } + ], + } + ] + + label = Label( + data=VideoData( + global_key="sample-video-4.mp4", + ), + annotations=bbox_annotation, + ) serialized = NDJsonConverter.serialize([label]) res = [x for x in serialized] assert res == expected diff --git a/libs/labelbox/tests/data/test_data_row_metadata.py b/libs/labelbox/tests/data/test_data_row_metadata.py index 1cadc4376..9a3690776 100644 --- a/libs/labelbox/tests/data/test_data_row_metadata.py +++ b/libs/labelbox/tests/data/test_data_row_metadata.py @@ -6,7 +6,13 @@ from labelbox import Dataset from labelbox.exceptions import MalformedQueryException from labelbox.schema.identifiables import GlobalKeys, UniqueIds -from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadata, DataRowMetadataKind, DataRowMetadataOntology, _parse_metadata_schema +from labelbox.schema.data_row_metadata import ( + DataRowMetadataField, + DataRowMetadata, + DataRowMetadataKind, + DataRowMetadataOntology, + _parse_metadata_schema, +) INVALID_SCHEMA_ID = "1" * 25 FAKE_SCHEMA_ID = "0" * 25 @@ -16,13 +22,13 @@ TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt" TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh" CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb" -CUSTOM_TEXT_SCHEMA_NAME = 'custom_text' +CUSTOM_TEXT_SCHEMA_NAME = "custom_text" FAKE_NUMBER_FIELD = { "id": FAKE_SCHEMA_ID, "name": "number", - "kind": 'CustomMetadataNumber', 
- "reserved": False + "kind": "CustomMetadataNumber", + "reserved": False, } @@ -42,12 +48,12 @@ def mdo(client): @pytest.fixture def big_dataset(dataset: Dataset, image_url): - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image" - }, - ] * 5) + task = dataset.create_data_rows( + [ + {"row_data": image_url, "external_id": "my-image"}, + ] + * 5 + ) task.wait_till_done() yield dataset @@ -61,11 +67,13 @@ def make_metadata(dr_id: str = None, gk: str = None) -> DataRowMetadata: global_key=gk, data_row_id=dr_id, fields=[ - DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID, - value=TEST_SPLIT_ID), + DataRowMetadataField( + schema_id=SPLIT_SCHEMA_ID, value=TEST_SPLIT_ID + ), DataRowMetadataField(schema_id=CAPTURE_DT_SCHEMA_ID, value=time), DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value=msg), - ]) + ], + ) return metadata @@ -73,29 +81,29 @@ def make_named_metadata(dr_id) -> DataRowMetadata: msg = "A message" time = datetime.utcnow() - metadata = DataRowMetadata(data_row_id=dr_id, - fields=[ - DataRowMetadataField(name='split', - value=TEST_SPLIT_ID), - DataRowMetadataField(name='captureDateTime', - value=time), - DataRowMetadataField( - name=CUSTOM_TEXT_SCHEMA_NAME, value=msg), - ]) + metadata = DataRowMetadata( + data_row_id=dr_id, + fields=[ + DataRowMetadataField(name="split", value=TEST_SPLIT_ID), + DataRowMetadataField(name="captureDateTime", value=time), + DataRowMetadataField(name=CUSTOM_TEXT_SCHEMA_NAME, value=msg), + ], + ) return metadata @pytest.mark.skip(reason="broken export v1 api, to be retired soon") -def test_export_empty_metadata(client, configured_project_with_label, - wait_for_data_row_processing): +def test_export_empty_metadata( + client, configured_project_with_label, wait_for_data_row_processing +): project, _, data_row, _ = configured_project_with_label data_row = wait_for_data_row_processing(client, data_row) - + export_task = project.export(params={"metadata_fields": True}) export_task.wait_till_done() stream = export_task.get_buffered_stream() data_row = [data_row.json for data_row in stream][0] - + assert data_row["metadata_fields"] == [] @@ -134,9 +142,11 @@ def test_get_datarow_metadata_ontology(mdo): value=datetime.utcnow(), ), DataRowMetadataField(schema_id=split.parent, value=split.uid), - DataRowMetadataField(schema_id=mdo.reserved_by_name["tag"].uid, - value="hello-world"), - ]) + DataRowMetadataField( + schema_id=mdo.reserved_by_name["tag"].uid, value="hello-world" + ), + ], + ) def test_bulk_upsert_datarow_metadata(data_row, mdo: DataRowMetadataOntology): @@ -148,7 +158,8 @@ def test_bulk_upsert_datarow_metadata(data_row, mdo: DataRowMetadataOntology): def test_bulk_upsert_datarow_metadata_by_globalkey( - data_rows, mdo: DataRowMetadataOntology): + data_rows, mdo: DataRowMetadataOntology +): global_keys = [data_row.global_key for data_row in data_rows] metadata = [make_metadata(gk=global_key) for global_key in global_keys] errors = mdo.bulk_upsert(metadata) @@ -169,8 +180,9 @@ def test_large_bulk_upsert_datarow_metadata(big_dataset, mdo): for metadata in mdo.bulk_export(data_row_ids) } for data_row_id in data_row_ids: - assert len([f for f in metadata_lookup.get(data_row_id).fields - ]), metadata_lookup.get(data_row_id).fields + assert len( + [f for f in metadata_lookup.get(data_row_id).fields] + ), metadata_lookup.get(data_row_id).fields def test_upsert_datarow_metadata_by_name(data_row, mdo): @@ -182,16 +194,18 @@ def test_upsert_datarow_metadata_by_name(data_row, mdo): metadata.data_row_id: metadata for metadata 
in mdo.bulk_export([data_row.uid]) } - assert len([f for f in metadata_lookup.get(data_row.uid).fields - ]), metadata_lookup.get(data_row.uid).fields + assert len( + [f for f in metadata_lookup.get(data_row.uid).fields] + ), metadata_lookup.get(data_row.uid).fields def test_upsert_datarow_metadata_option_by_name(data_row, mdo): - metadata = DataRowMetadata(data_row_id=data_row.uid, - fields=[ - DataRowMetadataField(name='split', - value='test'), - ]) + metadata = DataRowMetadata( + data_row_id=data_row.uid, + fields=[ + DataRowMetadataField(name="split", value="test"), + ], + ) errors = mdo.bulk_upsert([metadata]) assert len(errors) == 0 @@ -199,16 +213,17 @@ def test_upsert_datarow_metadata_option_by_name(data_row, mdo): assert len(datarows[0].fields) == 1 metadata = datarows[0].fields[0] assert metadata.schema_id == SPLIT_SCHEMA_ID - assert metadata.name == 'test' + assert metadata.name == "test" assert metadata.value == TEST_SPLIT_ID def test_upsert_datarow_metadata_option_by_incorrect_name(data_row, mdo): - metadata = DataRowMetadata(data_row_id=data_row.uid, - fields=[ - DataRowMetadataField(name='split', - value='test1'), - ]) + metadata = DataRowMetadata( + data_row_id=data_row.uid, + fields=[ + DataRowMetadataField(name="split", value="test1"), + ], + ) with pytest.raises(KeyError): mdo.bulk_upsert([metadata]) @@ -216,55 +231,47 @@ def test_upsert_datarow_metadata_option_by_incorrect_name(data_row, mdo): def test_raise_enum_upsert_schema_error(data_row, mdo): """Setting an option id as the schema id will raise a Value Error""" - metadata = DataRowMetadata(data_row_id=data_row.uid, - fields=[ - DataRowMetadataField(schema_id=TEST_SPLIT_ID, - value=SPLIT_SCHEMA_ID), - ]) + metadata = DataRowMetadata( + data_row_id=data_row.uid, + fields=[ + DataRowMetadataField( + schema_id=TEST_SPLIT_ID, value=SPLIT_SCHEMA_ID + ), + ], + ) with pytest.raises(ValueError): mdo.bulk_upsert([metadata]) def test_upsert_non_existent_schema_id(data_row, mdo): """Raise error on non-existent schema id""" - metadata = DataRowMetadata(data_row_id=data_row.uid, - fields=[ - DataRowMetadataField( - schema_id=INVALID_SCHEMA_ID, - value="message"), - ]) + metadata = DataRowMetadata( + data_row_id=data_row.uid, + fields=[ + DataRowMetadataField(schema_id=INVALID_SCHEMA_ID, value="message"), + ], + ) with pytest.raises(ValueError): mdo.bulk_upsert([metadata]) def test_parse_raw_metadata(mdo): example = { - 'dataRowId': - 'ckr6kkfx801ui0yrtg9fje8xh', - 'globalKey': - 'global-key-1', - 'fields': [ - { - 'schemaId': 'cko8s9r5v0001h2dk9elqdidh', - 'value': 'my-new-message' - }, - { - 'schemaId': 'cko8sbczn0002h2dkdaxb5kal', - 'value': {} - }, + "dataRowId": "ckr6kkfx801ui0yrtg9fje8xh", + "globalKey": "global-key-1", + "fields": [ { - 'schemaId': 'cko8sbscr0003h2dk04w86hof', - 'value': {} + "schemaId": "cko8s9r5v0001h2dk9elqdidh", + "value": "my-new-message", }, + {"schemaId": "cko8sbczn0002h2dkdaxb5kal", "value": {}}, + {"schemaId": "cko8sbscr0003h2dk04w86hof", "value": {}}, { - 'schemaId': 'cko8sdzv70006h2dk8jg64zvb', - 'value': '2021-07-20T21:41:14.606710Z' + "schemaId": "cko8sdzv70006h2dk8jg64zvb", + "value": "2021-07-20T21:41:14.606710Z", }, - { - 'schemaId': FAKE_SCHEMA_ID, - 'value': 0.5 - }, - ] + {"schemaId": FAKE_SCHEMA_ID, "value": 0.5}, + ], } parsed = mdo.parse_metadata([example]) @@ -281,26 +288,14 @@ def test_parse_raw_metadata(mdo): def test_parse_raw_metadata_fields(mdo): example = [ + {"schemaId": "cko8s9r5v0001h2dk9elqdidh", "value": "my-new-message"}, + {"schemaId": "cko8sbczn0002h2dkdaxb5kal", 
"value": {}}, + {"schemaId": "cko8sbscr0003h2dk04w86hof", "value": {}}, { - 'schemaId': 'cko8s9r5v0001h2dk9elqdidh', - 'value': 'my-new-message' - }, - { - 'schemaId': 'cko8sbczn0002h2dkdaxb5kal', - 'value': {} - }, - { - 'schemaId': 'cko8sbscr0003h2dk04w86hof', - 'value': {} - }, - { - 'schemaId': 'cko8sdzv70006h2dk8jg64zvb', - 'value': '2021-07-20T21:41:14.606710Z' - }, - { - 'schemaId': FAKE_SCHEMA_ID, - 'value': 0.5 + "schemaId": "cko8sdzv70006h2dk8jg64zvb", + "value": "2021-07-20T21:41:14.606710Z", }, + {"schemaId": FAKE_SCHEMA_ID, "value": 0.5}, ] parsed = mdo.parse_metadata_fields(example) @@ -312,35 +307,36 @@ def test_parse_raw_metadata_fields(mdo): def test_parse_metadata_schema(): unparsed = { - 'id': - 'cl467a4ec0046076g7s9yheoa', - 'name': - 'enum metadata', - 'kind': - 'CustomMetadataEnum', - 'options': [{ - 'id': 'cl467a4ec0047076ggjneeruy', - 'name': 'option1', - 'kind': 'CustomMetadataEnumOption' - }, { - 'id': 'cl4qa31u0009e078p5m280jer', - 'name': 'option2', - 'kind': 'CustomMetadataEnumOption' - }] + "id": "cl467a4ec0046076g7s9yheoa", + "name": "enum metadata", + "kind": "CustomMetadataEnum", + "options": [ + { + "id": "cl467a4ec0047076ggjneeruy", + "name": "option1", + "kind": "CustomMetadataEnumOption", + }, + { + "id": "cl4qa31u0009e078p5m280jer", + "name": "option2", + "kind": "CustomMetadataEnumOption", + }, + ], } parsed = _parse_metadata_schema(unparsed) - assert parsed.uid == 'cl467a4ec0046076g7s9yheoa' - assert parsed.name == 'enum metadata' + assert parsed.uid == "cl467a4ec0046076g7s9yheoa" + assert parsed.name == "enum metadata" assert parsed.kind == DataRowMetadataKind.enum assert len(parsed.options) == 2 - assert parsed.options[0].uid == 'cl467a4ec0047076ggjneeruy' + assert parsed.options[0].uid == "cl467a4ec0047076ggjneeruy" assert parsed.options[0].kind == DataRowMetadataKind.option def test_create_schema(mdo): metadata_name = str(uuid.uuid4()) - created_schema = mdo.create_schema(metadata_name, DataRowMetadataKind.enum, - ["option 1", "option 2"]) + created_schema = mdo.create_schema( + metadata_name, DataRowMetadataKind.enum, ["option 1", "option 2"] + ) assert created_schema.name == metadata_name assert created_schema.kind == DataRowMetadataKind.enum assert len(created_schema.options) == 2 @@ -350,10 +346,12 @@ def test_create_schema(mdo): def test_update_schema(mdo): metadata_name = str(uuid.uuid4()) - created_schema = mdo.create_schema(metadata_name, DataRowMetadataKind.enum, - ["option 1", "option 2"]) - updated_schema = mdo.update_schema(metadata_name, - f"{metadata_name}_updated") + created_schema = mdo.create_schema( + metadata_name, DataRowMetadataKind.enum, ["option 1", "option 2"] + ) + updated_schema = mdo.update_schema( + metadata_name, f"{metadata_name}_updated" + ) assert updated_schema.name == f"{metadata_name}_updated" assert updated_schema.uid == created_schema.uid assert updated_schema.kind == DataRowMetadataKind.enum @@ -362,10 +360,12 @@ def test_update_schema(mdo): def test_update_enum_options(mdo): metadata_name = str(uuid.uuid4()) - created_schema = mdo.create_schema(metadata_name, DataRowMetadataKind.enum, - ["option 1", "option 2"]) - updated_schema = mdo.update_enum_option(metadata_name, "option 1", - "option 3") + created_schema = mdo.create_schema( + metadata_name, DataRowMetadataKind.enum, ["option 1", "option 2"] + ) + updated_schema = mdo.update_enum_option( + metadata_name, "option 1", "option 3" + ) assert updated_schema.name == metadata_name assert updated_schema.uid == created_schema.uid assert updated_schema.kind == 
DataRowMetadataKind.enum @@ -376,23 +376,28 @@ def test_update_enum_options(mdo): def test_delete_schema(mdo): metadata_name = str(uuid.uuid4()) - created_schema = mdo.create_schema(metadata_name, - DataRowMetadataKind.string) + created_schema = mdo.create_schema( + metadata_name, DataRowMetadataKind.string + ) status = mdo.delete_schema(created_schema.name) mdo.refresh_ontology() assert status assert metadata_name not in mdo.custom_by_name -@pytest.mark.parametrize('datetime_str', - ['2011-11-04T00:05:23Z', '2011-11-04T00:05:23+00:00']) +@pytest.mark.parametrize( + "datetime_str", ["2011-11-04T00:05:23Z", "2011-11-04T00:05:23+00:00"] +) def test_upsert_datarow_date_metadata(data_row, mdo, datetime_str): metadata = [ - DataRowMetadata(data_row_id=data_row.uid, - fields=[ - DataRowMetadataField(name='captureDateTime', - value=datetime_str), - ]) + DataRowMetadata( + data_row_id=data_row.uid, + fields=[ + DataRowMetadataField( + name="captureDateTime", value=datetime_str + ), + ], + ) ] errors = mdo.bulk_upsert(metadata) assert len(errors) == 0 @@ -401,18 +406,22 @@ def test_upsert_datarow_date_metadata(data_row, mdo, datetime_str): assert f"{metadata[0].fields[0].value}" == "2011-11-04 00:05:23+00:00" -@pytest.mark.parametrize('datetime_str', - ['2011-11-04T00:05:23Z', '2011-11-04T00:05:23+00:00']) +@pytest.mark.parametrize( + "datetime_str", ["2011-11-04T00:05:23Z", "2011-11-04T00:05:23+00:00"] +) def test_create_data_row_with_metadata(dataset, image_url, datetime_str): client = dataset.client assert len(list(dataset.data_rows())) == 0 metadata_fields = [ - DataRowMetadataField(name='captureDateTime', value=datetime_str) + DataRowMetadataField(name="captureDateTime", value=datetime_str) ] - data_row = dataset.create_data_row(row_data=image_url, - metadata_fields=metadata_fields) + data_row = dataset.create_data_row( + row_data=image_url, metadata_fields=metadata_fields + ) retrieved_data_row = client.get_data_row(data_row.uid) - assert f"{retrieved_data_row.metadata[0].value}" == "2011-11-04 00:05:23+00:00" + assert ( + f"{retrieved_data_row.metadata[0].value}" == "2011-11-04 00:05:23+00:00" + ) diff --git a/libs/labelbox/tests/data/test_prefetch_generator.py b/libs/labelbox/tests/data/test_prefetch_generator.py index 2738f3640..b90074a9d 100644 --- a/libs/labelbox/tests/data/test_prefetch_generator.py +++ b/libs/labelbox/tests/data/test_prefetch_generator.py @@ -4,13 +4,12 @@ class ChildClassGenerator(PrefetchGenerator): - def __init__(self, examples, num_executors=1): super().__init__(data=examples, num_executors=num_executors) def _process(self, value): num = random() - if num < .2: + if num < 0.2: raise ValueError("Randomized value error") return value diff --git a/libs/labelbox/tests/integration/conftest.py b/libs/labelbox/tests/integration/conftest.py index 5b1f9aa9a..d37287fe8 100644 --- a/libs/labelbox/tests/integration/conftest.py +++ b/libs/labelbox/tests/integration/conftest.py @@ -17,7 +17,15 @@ from labelbox import Dataset, DataRow from labelbox import LabelingFrontend -from labelbox import OntologyBuilder, Tool, Option, Classification, MediaType, PromptResponseClassification, ResponseOption +from labelbox import ( + OntologyBuilder, + Tool, + Option, + Classification, + MediaType, + PromptResponseClassification, + ResponseOption, +) from labelbox.orm import query from labelbox.pagination import PaginatedCollection from labelbox.schema.annotation_import import LabelImport @@ -46,9 +54,10 @@ def project_based_user(client, rand_gen): newUserId } } - """ % (email, 
str(client.get_roles()['NONE'].uid)) - user_id = client.execute( - query_str)['addMembersToOrganization'][0]['newUserId'] + """ % (email, str(client.get_roles()["NONE"].uid)) + user_id = client.execute(query_str)["addMembersToOrganization"][0][ + "newUserId" + ] assert user_id is not None, "Unable to add user with old mutation" user = client._get_single(User, user_id) yield user @@ -58,9 +67,12 @@ def project_based_user(client, rand_gen): @pytest.fixture def project_pack(client): projects = [ - client.create_project(name=f"user-proj-{idx}", - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) for idx in range(2) + client.create_project( + name=f"user-proj-{idx}", + queue_mode=QueueMode.Batch, + media_type=MediaType.Image, + ) + for idx in range(2) ] yield projects for proj in projects: @@ -71,15 +83,18 @@ def project_pack(client): def project_with_empty_ontology(project): editor = list( project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] + where=LabelingFrontend.name == "editor" + ) + )[0] empty_ontology = {"tools": [], "classifications": []} project.setup(editor, empty_ontology) yield project @pytest.fixture -def configured_project(project_with_empty_ontology, initial_dataset, rand_gen, - image_url): +def configured_project( + project_with_empty_ontology, initial_dataset, rand_gen, image_url +): dataset = initial_dataset data_row_id = dataset.create_data_row(row_data=image_url).uid project = project_with_empty_ontology @@ -87,7 +102,7 @@ def configured_project(project_with_empty_ontology, initial_dataset, rand_gen, batch = project.create_batch( rand_gen(str), [data_row_id], # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) project.data_row_ids = [data_row_id] @@ -97,11 +112,14 @@ def configured_project(project_with_empty_ontology, initial_dataset, rand_gen, @pytest.fixture -def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, - image_url): - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) +def configured_project_with_complex_ontology( + client, initial_dataset, rand_gen, image_url +): + project = client.create_project( + name=rand_gen(str), + queue_mode=QueueMode.Batch, + media_type=MediaType.Image, + ) dataset = initial_dataset data_row = dataset.create_data_row(row_data=image_url) data_row_ids = [data_row.uid] @@ -109,13 +127,15 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, project.create_batch( rand_gen(str), data_row_ids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) + 5, # priority between 1(Highest) - 5(lowest) ) project.data_row_ids = data_row_ids editor = list( project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] + where=LabelingFrontend.name == "editor" + ) + )[0] ontology = OntologyBuilder() tools = [ @@ -123,24 +143,29 @@ def configured_project_with_complex_ontology(client, initial_dataset, rand_gen, Tool(tool=Tool.Type.LINE, name="test-line-class"), Tool(tool=Tool.Type.POINT, name="test-point-class"), Tool(tool=Tool.Type.POLYGON, name="test-polygon-class"), - Tool(tool=Tool.Type.NER, name="test-ner-class") + Tool(tool=Tool.Type.NER, name="test-ner-class"), ] options = [ Option(value="first option answer"), Option(value="second option answer"), - Option(value="third option answer") + Option(value="third option answer"), ] classifications = [ - 
Classification(class_type=Classification.Type.TEXT, - name="test-text-class"), - Classification(class_type=Classification.Type.RADIO, - name="test-radio-class", - options=options), - Classification(class_type=Classification.Type.CHECKLIST, - name="test-checklist-class", - options=options) + Classification( + class_type=Classification.Type.TEXT, name="test-text-class" + ), + Classification( + class_type=Classification.Type.RADIO, + name="test-radio-class", + options=options, + ), + Classification( + class_type=Classification.Type.CHECKLIST, + name="test-checklist-class", + options=options, + ), ] for t in tools: @@ -161,19 +186,22 @@ def ontology(client): ontology_builder = OntologyBuilder( tools=[ Tool(tool=Tool.Type.BBOX, name="Box 1", color="#ff0000"), - Tool(tool=Tool.Type.BBOX, name="Box 2", color="#ff0000") + Tool(tool=Tool.Type.BBOX, name="Box 2", color="#ff0000"), ], classifications=[ - Classification(name="Root Class", - class_type=Classification.Type.RADIO, - options=[ - Option(value="1", label="Option 1"), - Option(value="2", label="Option 2") - ]) - ]) - ontology = client.create_ontology('Integration Test Ontology', - ontology_builder.asdict(), - MediaType.Image) + Classification( + name="Root Class", + class_type=Classification.Type.RADIO, + options=[ + Option(value="1", label="Option 1"), + Option(value="2", label="Option 2"), + ], + ) + ], + ) + ontology = client.create_ontology( + "Integration Test Ontology", ontology_builder.asdict(), MediaType.Image + ) yield ontology client.delete_unused_ontology(ontology.uid) @@ -191,12 +219,9 @@ def video_data(client, rand_gen, video_data_row, wait_for_data_row_processing): def create_video_data_row(rand_gen): return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{rand_gen(str)}", - "media_type": - "VIDEO", + "row_data": "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", + "global_key": f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{rand_gen(str)}", + "media_type": "VIDEO", } @@ -218,25 +243,25 @@ def video_data_row(rand_gen): class ExportV2Helpers: - @classmethod - def run_project_export_v2_task(cls, - project, - num_retries=5, - task_name=None, - filters={}, - params={}): + def run_project_export_v2_task( + cls, project, num_retries=5, task_name=None, filters={}, params={} + ): task = None - params = params if params else { - "project_details": True, - "performance_details": False, - "data_row_details": True, - "label_details": True - } - while (num_retries > 0): - task = project.export_v2(task_name=task_name, - filters=filters, - params=params) + params = ( + params + if params + else { + "project_details": True, + "performance_details": False, + "data_row_details": True, + "label_details": True, + } + ) + while num_retries > 0: + task = project.export_v2( + task_name=task_name, filters=filters, params=params + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None @@ -248,21 +273,19 @@ def run_project_export_v2_task(cls, return task.result @classmethod - def run_dataset_export_v2_task(cls, - dataset, - num_retries=5, - task_name=None, - filters={}, - params={}): + def run_dataset_export_v2_task( + cls, dataset, num_retries=5, task_name=None, filters={}, params={} + ): task = None - params = params if params else { - "performance_details": False, - 
"label_details": True - } - while (num_retries > 0): - task = dataset.export_v2(task_name=task_name, - filters=filters, - params=params) + params = ( + params + if params + else {"performance_details": False, "label_details": True} + ) + while num_retries > 0: + task = dataset.export_v2( + task_name=task_name, filters=filters, params=params + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None @@ -275,23 +298,20 @@ def run_dataset_export_v2_task(cls, return task.result @classmethod - def run_catalog_export_v2_task(cls, - client, - num_retries=5, - task_name=None, - filters={}, - params={}): + def run_catalog_export_v2_task( + cls, client, num_retries=5, task_name=None, filters={}, params={} + ): task = None - params = params if params else { - "performance_details": False, - "label_details": True - } + params = ( + params + if params + else {"performance_details": False, "label_details": True} + ) catalog = client.get_catalog() - while (num_retries > 0): - - task = catalog.export_v2(task_name=task_name, - filters=filters, - params=params) + while num_retries > 0: + task = catalog.export_v2( + task_name=task_name, filters=filters, params=params + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is None @@ -317,9 +337,10 @@ def big_dataset_data_row_ids(big_dataset: Dataset): yield [dr.json["data_row"]["id"] for dr in stream] -@pytest.fixture(scope='function') -def dataset_with_invalid_data_rows(unique_dataset: Dataset, - upload_invalid_data_rows_for_dataset): +@pytest.fixture(scope="function") +def dataset_with_invalid_data_rows( + unique_dataset: Dataset, upload_invalid_data_rows_for_dataset +): upload_invalid_data_rows_for_dataset(unique_dataset) yield unique_dataset @@ -327,29 +348,33 @@ def dataset_with_invalid_data_rows(unique_dataset: Dataset, @pytest.fixture def upload_invalid_data_rows_for_dataset(): - def _upload_invalid_data_rows_for_dataset(dataset: Dataset): - task = dataset.create_data_rows([ - { - "row_data": 'gs://invalid-bucket/example.png', # forbidden - "external_id": "image-without-access.jpg" - }, - ] * 2) + task = dataset.create_data_rows( + [ + { + "row_data": "gs://invalid-bucket/example.png", # forbidden + "external_id": "image-without-access.jpg", + }, + ] + * 2 + ) task.wait_till_done() return _upload_invalid_data_rows_for_dataset @pytest.fixture -def prompt_response_generation_project_with_new_dataset(client: Client, - rand_gen, request): +def prompt_response_generation_project_with_new_dataset( + client: Client, rand_gen, request +): """fixture is parametrize and needs project_type in request""" media_type = request.param prompt_response_project = client.create_prompt_response_generation_project( name=f"{media_type.value}-{rand_gen(str)}", dataset_name=f"{media_type.value}-{rand_gen(str)}", data_row_count=1, - media_type=media_type) + media_type=media_type, + ) yield prompt_response_project @@ -357,15 +382,17 @@ def prompt_response_generation_project_with_new_dataset(client: Client, @pytest.fixture -def prompt_response_generation_project_with_dataset_id(client: Client, dataset, - rand_gen, request): +def prompt_response_generation_project_with_dataset_id( + client: Client, dataset, rand_gen, request +): """fixture is parametrized and needs project_type in request""" media_type = request.param prompt_response_project = client.create_prompt_response_generation_project( name=f"{media_type.value}-{rand_gen(str)}", dataset_id=dataset.uid, data_row_count=1, - media_type=media_type) + media_type=media_type, + ) 
yield prompt_response_project @@ -384,10 +411,10 @@ def response_creation_project(client: Client, rand_gen): @pytest.fixture def prompt_response_features(rand_gen): - prompt_text = PromptResponseClassification( class_type=PromptResponseClassification.Type.PROMPT, - name=f"{rand_gen(str)}-prompt text") + name=f"{rand_gen(str)}-prompt text", + ) response_radio = PromptResponseClassification( class_type=PromptResponseClassification.Type.RESPONSE_RADIO, @@ -395,27 +422,33 @@ def prompt_response_features(rand_gen): options=[ ResponseOption(value=f"{rand_gen(str)}-first radio option answer"), ResponseOption(value=f"{rand_gen(str)}-second radio option answer"), - ]) + ], + ) response_checklist = PromptResponseClassification( class_type=PromptResponseClassification.Type.RESPONSE_CHECKLIST, name=f"{rand_gen(str)}-response checklist classification", options=[ ResponseOption( - value=f"{rand_gen(str)}-first checklist option answer"), + value=f"{rand_gen(str)}-first checklist option answer" + ), ResponseOption( - value=f"{rand_gen(str)}-second checklist option answer"), - ]) + value=f"{rand_gen(str)}-second checklist option answer" + ), + ], + ) response_text_with_char = PromptResponseClassification( class_type=PromptResponseClassification.Type.RESPONSE_TEXT, name=f"{rand_gen(str)}-response text with character min and max", character_min=1, - character_max=10) + character_max=10, + ) response_text = PromptResponseClassification( class_type=PromptResponseClassification.Type.RESPONSE_TEXT, - name=f"{rand_gen(str)}-response text") + name=f"{rand_gen(str)}-response text", + ) nested_response_radio = PromptResponseClassification( class_type=PromptResponseClassification.Type.RESPONSE_RADIO, @@ -425,54 +458,65 @@ def prompt_response_features(rand_gen): f"{rand_gen(str)}-first_radio_answer", options=[ PromptResponseClassification( - class_type=PromptResponseClassification.Type. 
- RESPONSE_RADIO, + class_type=PromptResponseClassification.Type.RESPONSE_RADIO, name=f"{rand_gen(str)}-sub_radio_question", options=[ ResponseOption( - f"{rand_gen(str)}-first_sub_radio_answer") - ]) - ]) - ]) + f"{rand_gen(str)}-first_sub_radio_answer" + ) + ], + ) + ], + ) + ], + ) yield { "prompts": [prompt_text], "responses": [ - response_text, response_radio, response_checklist, - response_text_with_char, nested_response_radio - ] + response_text, + response_radio, + response_checklist, + response_text_with_char, + nested_response_radio, + ], } @pytest.fixture -def prompt_response_ontology(client: Client, rand_gen, prompt_response_features, - request): +def prompt_response_ontology( + client: Client, rand_gen, prompt_response_features, request +): """fixture is parametrize and needs project_type in request""" project_type = request.param if project_type == MediaType.LLMPromptCreation: ontology_builder = OntologyBuilder( - tools=[], classifications=prompt_response_features["prompts"]) + tools=[], classifications=prompt_response_features["prompts"] + ) elif project_type == MediaType.LLMPromptResponseCreation: ontology_builder = OntologyBuilder( tools=[], - classifications=prompt_response_features["prompts"] + - prompt_response_features["responses"]) + classifications=prompt_response_features["prompts"] + + prompt_response_features["responses"], + ) else: ontology_builder = OntologyBuilder( - tools=[], classifications=prompt_response_features["responses"]) + tools=[], classifications=prompt_response_features["responses"] + ) ontology_name = f"prompt-response-{rand_gen(str)}" if project_type in MediaType: - ontology = client.create_ontology(ontology_name, - ontology_builder.asdict(), - media_type=project_type) + ontology = client.create_ontology( + ontology_name, ontology_builder.asdict(), media_type=project_type + ) else: ontology = client.create_ontology( ontology_name, ontology_builder.asdict(), media_type=MediaType.Text, - ontology_kind=OntologyKind.ResponseCreation) + ontology_kind=OntologyKind.ResponseCreation, + ) yield ontology featureSchemaIds = [ @@ -503,7 +547,8 @@ def feature_schema(client, point): yield created_feature_schema client.delete_unused_feature_schema( - created_feature_schema.normalized['featureSchemaId']) + created_feature_schema.normalized["featureSchemaId"] + ) @pytest.fixture @@ -511,55 +556,75 @@ def chat_evaluation_ontology(client, rand_gen): ontology_name = f"test-chat-evaluation-ontology-{rand_gen(str)}" ontology_builder = OntologyBuilder( tools=[ - Tool(tool=Tool.Type.MESSAGE_SINGLE_SELECTION, - name="model output single selection"), - Tool(tool=Tool.Type.MESSAGE_MULTI_SELECTION, - name="model output multi selection"), - Tool(tool=Tool.Type.MESSAGE_RANKING, - name="model output multi ranking"), + Tool( + tool=Tool.Type.MESSAGE_SINGLE_SELECTION, + name="model output single selection", + ), + Tool( + tool=Tool.Type.MESSAGE_MULTI_SELECTION, + name="model output multi selection", + ), + Tool( + tool=Tool.Type.MESSAGE_RANKING, + name="model output multi ranking", + ), ], classifications=[ - Classification(class_type=Classification.Type.TEXT, - name="global model output text classification", - scope=Classification.Scope.GLOBAL), - Classification(class_type=Classification.Type.RADIO, - name="global model output radio classification", - scope=Classification.Scope.GLOBAL, - options=[ - Option(value="global first option answer"), - Option(value="global second option answer"), - ]), - Classification(class_type=Classification.Type.CHECKLIST, - name="global model 
output checklist classification", - scope=Classification.Scope.GLOBAL, - options=[ - Option(value="global first option answer"), - Option(value="global second option answer"), - ]), - Classification(class_type=Classification.Type.TEXT, - name="index model output text classification", - scope=Classification.Scope.INDEX), - Classification(class_type=Classification.Type.RADIO, - name="index model output radio classification", - scope=Classification.Scope.INDEX, - options=[ - Option(value="index first option answer"), - Option(value="index second option answer"), - ]), - Classification(class_type=Classification.Type.CHECKLIST, - name="index model output checklist classification", - scope=Classification.Scope.INDEX, - options=[ - Option(value="index first option answer"), - Option(value="index second option answer"), - ]), - ]) + Classification( + class_type=Classification.Type.TEXT, + name="global model output text classification", + scope=Classification.Scope.GLOBAL, + ), + Classification( + class_type=Classification.Type.RADIO, + name="global model output radio classification", + scope=Classification.Scope.GLOBAL, + options=[ + Option(value="global first option answer"), + Option(value="global second option answer"), + ], + ), + Classification( + class_type=Classification.Type.CHECKLIST, + name="global model output checklist classification", + scope=Classification.Scope.GLOBAL, + options=[ + Option(value="global first option answer"), + Option(value="global second option answer"), + ], + ), + Classification( + class_type=Classification.Type.TEXT, + name="index model output text classification", + scope=Classification.Scope.INDEX, + ), + Classification( + class_type=Classification.Type.RADIO, + name="index model output radio classification", + scope=Classification.Scope.INDEX, + options=[ + Option(value="index first option answer"), + Option(value="index second option answer"), + ], + ), + Classification( + class_type=Classification.Type.CHECKLIST, + name="index model output checklist classification", + scope=Classification.Scope.INDEX, + options=[ + Option(value="index first option answer"), + Option(value="index second option answer"), + ], + ), + ], + ) ontology = client.create_ontology( ontology_name, ontology_builder.asdict(), media_type=MediaType.Conversational, - ontology_kind=OntologyKind.ModelEvaluation) + ontology_kind=OntologyKind.ModelEvaluation, + ) yield ontology @@ -573,9 +638,9 @@ def chat_evaluation_ontology(client, rand_gen): def live_chat_evaluation_project_with_new_dataset(client, rand_gen): project_name = f"test-model-evaluation-project-{rand_gen(str)}" dataset_name = f"test-model-evaluation-dataset-{rand_gen(str)}" - project = client.create_model_evaluation_project(name=project_name, - dataset_name=dataset_name, - data_row_count=1) + project = client.create_model_evaluation_project( + name=project_name, dataset_name=dataset_name, data_row_count=1 + ) yield project @@ -596,9 +661,9 @@ def offline_chat_evaluation_project(client, rand_gen): def chat_evaluation_project_append_to_dataset(client, dataset, rand_gen): project_name = f"test-model-evaluation-project-{rand_gen(str)}" dataset_id = dataset.uid - project = client.create_model_evaluation_project(name=project_name, - dataset_id=dataset_id, - data_row_count=1) + project = client.create_model_evaluation_project( + name=project_name, dataset_id=dataset_id, data_row_count=1 + ) yield project @@ -613,106 +678,102 @@ def offline_conversational_data_row(initial_dataset): "actors": { "clxhs9wk000013b6w7imiz0h8": { "role": 
"human", - "metadata": { - "name": "User" - } + "metadata": {"name": "User"}, }, "clxhsc6xb00013b6w1awh579j": { "role": "model", "metadata": { "modelConfigId": "5a50d319-56bd-405d-87bb-4442daea0d0f" - } + }, }, "clxhsc6xb00023b6wlp0768zs": { "role": "model", "metadata": { "modelConfigId": "1cfc833a-2684-47df-95ac-bb7d9f9e3e1f" - } - } + }, + }, }, "messages": { "clxhs9wk000023b6wrufora3k": { "actorId": "clxhs9wk000013b6w7imiz0h8", - "content": [{ - "type": "text", - "content": "Hello world" - }], - "childMessageIds": ["clxhscb4z00033b6wukpvmuol"] + "content": [{"type": "text", "content": "Hello world"}], + "childMessageIds": ["clxhscb4z00033b6wukpvmuol"], }, "clxhscb4z00033b6wukpvmuol": { "actorId": "clxhsc6xb00013b6w1awh579j", - "content": [{ - "type": - "text", - "content": - "Hello to you too! 👋 \n\nIt's great to be your guide in the digital world. What can I help you with today? 😊 \n" - }], - "childMessageIds": ["clxhu2s0900013b6wbv0ndddd"] + "content": [ + { + "type": "text", + "content": "Hello to you too! 👋 \n\nIt's great to be your guide in the digital world. What can I help you with today? 😊 \n", + } + ], + "childMessageIds": ["clxhu2s0900013b6wbv0ndddd"], }, "clxhu2s0900013b6wbv0ndddd": { - "actorId": - "clxhs9wk000013b6w7imiz0h8", - "content": [{ - "type": "text", - "content": "Lets some some multi-turn happening" - }], + "actorId": "clxhs9wk000013b6w7imiz0h8", + "content": [ + { + "type": "text", + "content": "Lets some some multi-turn happening", + } + ], "childMessageIds": [ - "clxhu4qib00023b6wuep47b1l", "clxhu4qib00033b6wf18az01q" - ] + "clxhu4qib00023b6wuep47b1l", + "clxhu4qib00033b6wf18az01q", + ], }, "clxhu4qib00023b6wuep47b1l": { "actorId": "clxhsc6xb00013b6w1awh579j", - "content": [{ - "type": - "text", - "content": - "Okay, I'm ready for some multi-turn fun! To make it interesting, how about we try building a story together? \n\n**Here's the beginning:**\n\nThe old, dusty book lay forgotten on the shelf, its leather cover cracked and faded. But as the afternoon sun slanted through the window, a single ray caught a glint of gold on the book's spine. Suddenly...\n\n**Now you tell me what happens!** What does the glint of gold turn out to be? What happens next? 🤔 \n" - }], - "childMessageIds": ["clxhu596m00043b6wvkgahcwz"] + "content": [ + { + "type": "text", + "content": "Okay, I'm ready for some multi-turn fun! To make it interesting, how about we try building a story together? \n\n**Here's the beginning:**\n\nThe old, dusty book lay forgotten on the shelf, its leather cover cracked and faded. But as the afternoon sun slanted through the window, a single ray caught a glint of gold on the book's spine. Suddenly...\n\n**Now you tell me what happens!** What does the glint of gold turn out to be? What happens next? 🤔 \n", + } + ], + "childMessageIds": ["clxhu596m00043b6wvkgahcwz"], }, "clxhu4qib00033b6wf18az01q": { "actorId": "clxhsc6xb00023b6wlp0768zs", - "content": [{ - "type": - "text", - "content": - "Okay, I'm game! Let's have a multi-turn conversation! But to get us started, give me a topic or a question to get us going! 😄 \n\nWhat's on your mind? \n" - }], - "childMessageIds": ["clxhu596m00043b6wvkgahcwz"] + "content": [ + { + "type": "text", + "content": "Okay, I'm game! Let's have a multi-turn conversation! But to get us started, give me a topic or a question to get us going! 😄 \n\nWhat's on your mind? 
\n", + } + ], + "childMessageIds": ["clxhu596m00043b6wvkgahcwz"], }, "clxhu596m00043b6wvkgahcwz": { - "actorId": - "clxhs9wk000013b6w7imiz0h8", - "content": [{ - "type": "text", - "content": "Can you generate an image?" - }], + "actorId": "clxhs9wk000013b6w7imiz0h8", + "content": [ + {"type": "text", "content": "Can you generate an image?"} + ], "childMessageIds": [ - "clxhu5hqk00053b6wk21u6xi7", "clxhu5hqk00063b6whl4gofx8" - ] + "clxhu5hqk00053b6wk21u6xi7", + "clxhu5hqk00063b6whl4gofx8", + ], }, "clxhu5hqk00053b6wk21u6xi7": { "actorId": "clxhsc6xb00013b6w1awh579j", - "content": [{ - "type": - "text", - "content": - "I can't generate images directly yet. I'm a text-based AI, so my superpowers are all about words! 🪄 \n\nHowever, I'd love to help you describe an image that fits our story. Tell me: \n\n* **What should the image show?** (For example: The dusty book, the glint of gold, the bookshelf, the room, etc.) \n* **What kind of feeling or atmosphere should the image have?** (Mysterious? Magical? Ominous?) \n\nThe more details you give me, the better I can describe the image! ✨ \n" - }], - "childMessageIds": [] + "content": [ + { + "type": "text", + "content": "I can't generate images directly yet. I'm a text-based AI, so my superpowers are all about words! 🪄 \n\nHowever, I'd love to help you describe an image that fits our story. Tell me: \n\n* **What should the image show?** (For example: The dusty book, the glint of gold, the bookshelf, the room, etc.) \n* **What kind of feeling or atmosphere should the image have?** (Mysterious? Magical? Ominous?) \n\nThe more details you give me, the better I can describe the image! ✨ \n", + } + ], + "childMessageIds": [], }, "clxhu5hqk00063b6whl4gofx8": { "actorId": "clxhsc6xb00023b6wlp0768zs", - "content": [{ - "type": - "text", - "content": - "I can't *actually* generate images directly. 😔 I'm primarily a text-based AI. \n\nTo help me understand what you'd like to see, tell me: \n\n* **What should be in the image?** Be specific! (e.g., \"A cat wearing a tiny hat\", \"A futuristic cityscape at sunset\")\n* **What style do you imagine?** (e.g., realistic, cartoonish, abstract)\n\nOnce you give me those details, I can try to give you a vivid description that's almost as good as seeing it! 😊 \n" - }], - "childMessageIds": [] - } + "content": [ + { + "type": "text", + "content": "I can't *actually* generate images directly. 😔 I'm primarily a text-based AI. \n\nTo help me understand what you'd like to see, tell me: \n\n* **What should be in the image?** Be specific! (e.g., \"A cat wearing a tiny hat\", \"A futuristic cityscape at sunset\")\n* **What style do you imagine?** (e.g., realistic, cartoonish, abstract)\n\nOnce you give me those details, I can try to give you a vivid description that's almost as good as seeing it! 
😊 \n", + } + ], + "childMessageIds": [], + }, }, - "rootMessageIds": ["clxhs9wk000023b6wrufora3k"] + "rootMessageIds": ["clxhs9wk000023b6wrufora3k"], } convo_v2_asset = { @@ -734,10 +795,8 @@ def response_data_row(initial_dataset): @pytest.fixture() def conversation_data_row(initial_dataset, rand_gen): data = { - "row_data": - "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", - "global_key": - f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{rand_gen(str)}", + "row_data": "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", + "global_key": f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{rand_gen(str)}", } convo_asset = {"row_data": data} data_row = initial_dataset.create_data_row(convo_asset) @@ -760,16 +819,19 @@ def pytest_fixture_setup(fixturedef): pytest.report[fixturedef.argname] += exec_time -@pytest.fixture(scope='session', autouse=True) +@pytest.fixture(scope="session", autouse=True) def print_perf_summary(): yield if "FIXTURE_PROFILE" in os.environ: sorted_dict = dict( - sorted(pytest.report.items(), - key=lambda item: item[1], - reverse=True)) + sorted( + pytest.report.items(), key=lambda item: item[1], reverse=True + ) + ) num_of_entries = 10 if len(sorted_dict) >= 10 else len(sorted_dict) - slowest_fixtures = [(aaa, sorted_dict[aaa]) - for aaa in islice(sorted_dict, num_of_entries)] + slowest_fixtures = [ + (aaa, sorted_dict[aaa]) + for aaa in islice(sorted_dict, num_of_entries) + ] print("\nTop slowest fixtures:\n", slowest_fixtures, file=sys.stderr) diff --git a/libs/labelbox/tests/integration/schema/test_user_group.py b/libs/labelbox/tests/integration/schema/test_user_group.py index 27882e2d7..6aebd4e89 100644 --- a/libs/labelbox/tests/integration/schema/test_user_group.py +++ b/libs/labelbox/tests/integration/schema/test_user_group.py @@ -3,10 +3,15 @@ from uuid import uuid4 from labelbox import Client from labelbox.schema.user_group import UserGroup, UserGroupColor -from labelbox.exceptions import ResourceNotFoundError, ResourceCreationError, UnprocessableEntityError +from labelbox.exceptions import ( + ResourceNotFoundError, + ResourceCreationError, + UnprocessableEntityError, +) data = faker.Faker() + @pytest.fixture def user_group(client): group_name = data.name() @@ -141,7 +146,7 @@ def test_cannot_update_group_id(user_group): def test_get_user_groups_with_creation_deletion(client): user_group = None - try: + try: # Get all user groups user_groups = list(UserGroup(client).get_user_groups()) @@ -167,7 +172,9 @@ def test_get_user_groups_with_creation_deletion(client): user_groups_post_deletion = list(UserGroup(client).get_user_groups()) - assert len(user_groups_post_deletion) == len(user_groups_post_creation) - 1 + assert ( + len(user_groups_post_deletion) == len(user_groups_post_creation) - 1 + ) finally: if user_group: @@ -217,4 +224,5 @@ def test_throw_error_delete_user_group_no_id(user_group, client): if __name__ == "__main__": import subprocess - subprocess.call(["pytest", "-v", __file__]) \ No newline at end of file + + subprocess.call(["pytest", "-v", __file__]) diff --git a/libs/labelbox/tests/integration/test_batch.py b/libs/labelbox/tests/integration/test_batch.py index d5e3b7a0f..3f9e720a3 100644 --- a/libs/labelbox/tests/integration/test_batch.py +++ 
b/libs/labelbox/tests/integration/test_batch.py @@ -4,7 +4,12 @@ import pytest from labelbox import Dataset, Project -from labelbox.exceptions import ProcessingWaitTimeout, MalformedQueryException, ResourceConflict, LabelboxError +from labelbox.exceptions import ( + ProcessingWaitTimeout, + MalformedQueryException, + ResourceConflict, + LabelboxError, +) def get_data_row_ids(ds: Dataset): @@ -12,13 +17,12 @@ def get_data_row_ids(ds: Dataset): def test_create_batch(project: Project, big_dataset_data_row_ids: List[str]): - batch = project.create_batch("test-batch", - big_dataset_data_row_ids, - 3, - consensus_settings={ - 'number_of_labels': 3, - 'coverage_percentage': 0.1 - }) + batch = project.create_batch( + "test-batch", + big_dataset_data_row_ids, + 3, + consensus_settings={"number_of_labels": 3, "coverage_percentage": 0.1}, + ) assert batch.name == "test-batch" assert batch.size == len(big_dataset_data_row_ids) @@ -27,86 +31,101 @@ def test_create_batch(project: Project, big_dataset_data_row_ids: List[str]): def test_create_batch_with_invalid_data_rows_ids(project: Project): with pytest.raises(MalformedQueryException) as ex: - project.create_batch("test-batch", data_rows=['a', 'b', 'c']) - assert str( - ex) == "No valid data rows to be added from the list provided!" - - -def test_create_batch_with_the_same_name(project: Project, - small_dataset: Dataset): - batch1 = project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset)) + project.create_batch("test-batch", data_rows=["a", "b", "c"]) + assert ( + str(ex) == "No valid data rows to be added from the list provided!" + ) + + +def test_create_batch_with_the_same_name( + project: Project, small_dataset: Dataset +): + batch1 = project.create_batch( + "batch1", data_rows=get_data_row_ids(small_dataset) + ) assert batch1.name == "batch1" with pytest.raises(ResourceConflict): - project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset)) + project.create_batch( + "batch1", data_rows=get_data_row_ids(small_dataset) + ) -def test_create_batch_with_same_data_row_ids(project: Project, - small_dataset: Dataset): - batch1 = project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset)) +def test_create_batch_with_same_data_row_ids( + project: Project, small_dataset: Dataset +): + batch1 = project.create_batch( + "batch1", data_rows=get_data_row_ids(small_dataset) + ) assert batch1.name == "batch1" with pytest.raises(MalformedQueryException) as ex: - project.create_batch("batch2", - data_rows=get_data_row_ids(small_dataset)) + project.create_batch( + "batch2", data_rows=get_data_row_ids(small_dataset) + ) assert str(ex) == "No valid data rows to add to project" def test_create_batch_with_non_existent_global_keys(project: Project): with pytest.raises(MalformedQueryException) as ex: project.create_batch("batch1", global_keys=["key1"]) - assert str( - ex - ) == "Data rows with the following global keys do not exist: key1." + assert ( + str(ex) + == "Data rows with the following global keys do not exist: key1." 
+ ) -def test_create_batch_with_string_priority(project: Project, - small_dataset: Dataset): +def test_create_batch_with_string_priority( + project: Project, small_dataset: Dataset +): with pytest.raises(LabelboxError): - project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset), - priority="abcd") + project.create_batch( + "batch1", data_rows=get_data_row_ids(small_dataset), priority="abcd" + ) -def test_create_batch_with_null_priority(project: Project, - small_dataset: Dataset): +def test_create_batch_with_null_priority( + project: Project, small_dataset: Dataset +): with pytest.raises(LabelboxError): - project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset), - priority=None) + project.create_batch( + "batch1", data_rows=get_data_row_ids(small_dataset), priority=None + ) -def test_create_batch_async(project: Project, - big_dataset_data_row_ids: List[str]): - batch = project._create_batch_async("big-batch", - big_dataset_data_row_ids, - priority=3) +def test_create_batch_async( + project: Project, big_dataset_data_row_ids: List[str] +): + batch = project._create_batch_async( + "big-batch", big_dataset_data_row_ids, priority=3 + ) assert batch.name == "big-batch" assert batch.size == len(big_dataset_data_row_ids) assert len([dr for dr in batch.failed_data_row_ids]) == 0 -def test_create_batch_with_consensus_settings(project: Project, - small_dataset: Dataset): +def test_create_batch_with_consensus_settings( + project: Project, small_dataset: Dataset +): export_task = small_dataset.export() export_task.wait_till_done() stream = export_task.get_buffered_stream() data_rows = [dr.json["data_row"]["id"] for dr in stream] consensus_settings = {"coverage_percentage": 0.1, "number_of_labels": 3} - batch = project.create_batch("batch with consensus settings", - data_rows, - 3, - consensus_settings=consensus_settings) + batch = project.create_batch( + "batch with consensus settings", + data_rows, + 3, + consensus_settings=consensus_settings, + ) assert batch.name == "batch with consensus settings" assert batch.size == len(data_rows) assert batch.consensus_settings == consensus_settings -def test_create_batch_with_data_row_class(project: Project, - small_dataset: Dataset): +def test_create_batch_with_data_row_class( + project: Project, small_dataset: Dataset +): export_task = small_dataset.export() export_task.wait_till_done() stream = export_task.get_buffered_stream() @@ -121,11 +140,11 @@ def test_archive_batch(project: Project, small_dataset: Dataset): export_task.wait_till_done() stream = export_task.get_buffered_stream() data_rows = [dr.json["data_row"]["id"] for dr in stream] - + batch = project.create_batch("batch to archive", data_rows) batch.remove_queued_data_rows() overview = project.get_overview() - + assert overview.to_label == 0 @@ -145,8 +164,9 @@ def test_batch_project(project: Project, small_dataset: Dataset): export_task.wait_till_done() stream = export_task.get_buffered_stream() data_rows = [dr.json["data_row"]["id"] for dr in stream] - batch = project.create_batch("batch to test project relationship", - data_rows) + batch = project.create_batch( + "batch to test project relationship", data_rows + ) project_from_batch = batch.project() @@ -155,8 +175,10 @@ def test_batch_project(project: Project, small_dataset: Dataset): def test_batch_creation_for_data_rows_with_issues( - project: Project, small_dataset: Dataset, - dataset_with_invalid_data_rows: Dataset): + project: Project, + small_dataset: Dataset, + dataset_with_invalid_data_rows: Dataset, 
+): """ Create a batch containing both valid and invalid data rows """ @@ -167,8 +189,9 @@ def test_batch_creation_for_data_rows_with_issues( data_rows_to_add = valid_data_rows + invalid_data_rows assert len(data_rows_to_add) == 4 - batch = project.create_batch("batch to test failed data rows", - data_rows_to_add) + batch = project.create_batch( + "batch to test failed data rows", data_rows_to_add + ) failed_data_row_ids = [x for x in batch.failed_data_row_ids] assert len(failed_data_row_ids) == 2 @@ -178,8 +201,11 @@ def test_batch_creation_for_data_rows_with_issues( def test_batch_creation_with_processing_timeout( - project: Project, small_dataset: Dataset, unique_dataset: Dataset, - upload_invalid_data_rows_for_dataset): + project: Project, + small_dataset: Dataset, + unique_dataset: Dataset, + upload_invalid_data_rows_for_dataset, +): """ Create a batch with zero wait time, this means that the waiting logic will throw exception immediately """ @@ -202,15 +228,16 @@ def test_batch_creation_with_processing_timeout( @pytest.mark.export_v1("export_v1 test remove later") -def test_export_data_rows(project: Project, dataset: Dataset, image_url: str, - external_id: str): +def test_export_data_rows( + project: Project, dataset: Dataset, image_url: str, external_id: str +): n_data_rows = 2 - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": external_id - }, - ] * n_data_rows) + task = dataset.create_data_rows( + [ + {"row_data": image_url, "external_id": external_id}, + ] + * n_data_rows + ) task.wait_till_done() data_rows = [dr.uid for dr in list(dataset.export_data_rows())] @@ -227,10 +254,10 @@ def test_list_all_batches(project: Project, client, image_url: str): Test to verify that we can retrieve all available batches in the project. """ # Data to use - img_assets = [{ - "row_data": image_url, - "external_id": str(uuid4()) - } for asset in range(0, 2)] + img_assets = [ + {"row_data": image_url, "external_id": str(uuid4())} + for asset in range(0, 2) + ] data = [img_assets for _ in range(0, 2)] # Setup @@ -245,8 +272,9 @@ def test_list_all_batches(project: Project, client, image_url: str): for dataset in datasets: data_row_ids = get_data_row_ids(dataset) - new_batch = project.create_batch(name=str(uuid4()), - data_rows=data_row_ids) + new_batch = project.create_batch( + name=str(uuid4()), data_rows=data_row_ids + ) batches.append(new_batch) # Test @@ -269,7 +297,8 @@ def test_list_project_batches_with_no_batches(project: Project): @pytest.mark.skip( reason="Test cannot be used effectively with MAL/LabelImport. \ -Fix/Unskip after resolving deletion with MAL/LabelImport") +Fix/Unskip after resolving deletion with MAL/LabelImport" +) def test_delete_labels(project, small_dataset): export_task = small_dataset.export() export_task.wait_till_done() @@ -280,14 +309,16 @@ def test_delete_labels(project, small_dataset): @pytest.mark.skip( reason="Test cannot be used effectively with MAL/LabelImport. 
\ -Fix/Unskip after resolving deletion with MAL/LabelImport") +Fix/Unskip after resolving deletion with MAL/LabelImport" +) def test_delete_labels_with_templates(project: Project, small_dataset: Dataset): export_task = small_dataset.export() export_task.wait_till_done() stream = export_task.get_buffered_stream() data_rows = [dr.json["data_row"]["id"] for dr in stream] - batch = project.create_batch("batch to delete labels w templates", - data_rows) + batch = project.create_batch( + "batch to delete labels w templates", data_rows + ) export_task = project.export(filters={"batch_ids": [batch.uid]}) export_task.wait_till_done() diff --git a/libs/labelbox/tests/integration/test_batches.py b/libs/labelbox/tests/integration/test_batches.py index 5c24a65f0..cabae4053 100644 --- a/libs/labelbox/tests/integration/test_batches.py +++ b/libs/labelbox/tests/integration/test_batches.py @@ -6,9 +6,9 @@ def test_create_batches(project: Project, big_dataset_data_row_ids: List[str]): - task = project.create_batches("test-batch", - big_dataset_data_row_ids, - priority=3) + task = project.create_batches( + "test-batch", big_dataset_data_row_ids, priority=3 + ) task.wait_till_done() assert task.errors() is None @@ -26,9 +26,9 @@ def test_create_batches_from_dataset(project: Project, big_dataset: Dataset): data_rows = [dr.json["data_row"]["id"] for dr in stream] project._wait_until_data_rows_are_processed(data_rows, [], 300) - task = project.create_batches_from_dataset("test-batch", - big_dataset.uid, - priority=3) + task = project.create_batches_from_dataset( + "test-batch", big_dataset.uid, priority=3 + ) task.wait_till_done() assert task.errors() is None diff --git a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py index aafcddbcc..47e39e2cf 100644 --- a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py +++ b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py @@ -7,9 +7,12 @@ def test_create_chat_evaluation_ontology_project( - client, chat_evaluation_ontology, - live_chat_evaluation_project_with_new_dataset, - offline_conversational_data_row, rand_gen): + client, + chat_evaluation_ontology, + live_chat_evaluation_project_with_new_dataset, + offline_conversational_data_row, + rand_gen, +): ontology = chat_evaluation_ontology # here we are essentially testing the ontology creation which is a fixture @@ -20,7 +23,7 @@ def test_create_chat_evaluation_ontology_project( assert tool.schema_id assert tool.feature_schema_id - assert (len(ontology.classifications()) == 6) + assert len(ontology.classifications()) == 6 for classification in ontology.classifications(): assert classification.schema_id assert classification.feature_schema_id @@ -34,29 +37,32 @@ def test_create_chat_evaluation_ontology_project( assert project.ontology().name == ontology.name with pytest.raises( - ValueError, - match="Cannot create batches for auto data generation projects"): + ValueError, + match="Cannot create batches for auto data generation projects", + ): project.create_batch( rand_gen(str), [offline_conversational_data_row.uid], # sample of data row objects ) with pytest.raises( - ValueError, - match="Cannot create batches for auto data generation projects"): - with patch('labelbox.schema.project.MAX_SYNC_BATCH_ROW_COUNT', - new=0): # force to async - + ValueError, + match="Cannot create batches for auto data generation projects", + ): + with patch( + 
"labelbox.schema.project.MAX_SYNC_BATCH_ROW_COUNT", new=0 + ): # force to async project.create_batch( rand_gen(str), - [offline_conversational_data_row.uid + [ + offline_conversational_data_row.uid ], # sample of data row objects ) def test_create_chat_evaluation_ontology_project_existing_dataset( - client, chat_evaluation_ontology, - chat_evaluation_project_append_to_dataset): + client, chat_evaluation_ontology, chat_evaluation_project_append_to_dataset +): ontology = chat_evaluation_ontology project = chat_evaluation_project_append_to_dataset @@ -69,31 +75,35 @@ def test_create_chat_evaluation_ontology_project_existing_dataset( @pytest.fixture def tools_json(): - tools = [{ - 'tool': 'message-single-selection', - 'name': 'model output single selection', - 'required': False, - 'color': '#ff0000', - 'classifications': [], - 'schemaNodeId': None, - 'featureSchemaId': None - }, { - 'tool': 'message-multi-selection', - 'name': 'model output multi selection', - 'required': False, - 'color': '#00ff00', - 'classifications': [], - 'schemaNodeId': None, - 'featureSchemaId': None - }, { - 'tool': 'message-ranking', - 'name': 'model output multi ranking', - 'required': False, - 'color': '#0000ff', - 'classifications': [], - 'schemaNodeId': None, - 'featureSchemaId': None - }] + tools = [ + { + "tool": "message-single-selection", + "name": "model output single selection", + "required": False, + "color": "#ff0000", + "classifications": [], + "schemaNodeId": None, + "featureSchemaId": None, + }, + { + "tool": "message-multi-selection", + "name": "model output multi selection", + "required": False, + "color": "#00ff00", + "classifications": [], + "schemaNodeId": None, + "featureSchemaId": None, + }, + { + "tool": "message-ranking", + "name": "model output multi ranking", + "required": False, + "color": "#0000ff", + "classifications": [], + "schemaNodeId": None, + "featureSchemaId": None, + }, + ] return tools @@ -124,19 +134,21 @@ def ontology_from_feature_ids(client, features_from_json): client.delete_unused_ontology(ontology.uid) -def test_ontology_create_feature_schema(ontology_from_feature_ids, - features_from_json, tools_json): +def test_ontology_create_feature_schema( + ontology_from_feature_ids, features_from_json, tools_json +): created_ontology = ontology_from_feature_ids feature_schema_ids = {f.uid for f in features_from_json} - tools_normalized = created_ontology.normalized['tools'] + tools_normalized = created_ontology.normalized["tools"] tools = tools_json for tool in tools: generated_tool = next( - t for t in tools_normalized if t['name'] == tool['name']) - assert generated_tool['schemaNodeId'] is not None - assert generated_tool['featureSchemaId'] in feature_schema_ids - assert generated_tool['tool'] == tool['tool'] - assert generated_tool['name'] == tool['name'] - assert generated_tool['required'] == tool['required'] - assert generated_tool['color'] == tool['color'] + t for t in tools_normalized if t["name"] == tool["name"] + ) + assert generated_tool["schemaNodeId"] is not None + assert generated_tool["featureSchemaId"] in feature_schema_ids + assert generated_tool["tool"] == tool["tool"] + assert generated_tool["name"] == tool["name"] + assert generated_tool["required"] == tool["required"] + assert generated_tool["color"] == tool["color"] diff --git a/libs/labelbox/tests/integration/test_client_errors.py b/libs/labelbox/tests/integration/test_client_errors.py index 411b9e3b0..64b8fb626 100644 --- a/libs/labelbox/tests/integration/test_client_errors.py +++ 
b/libs/labelbox/tests/integration/test_client_errors.py @@ -40,7 +40,7 @@ def test_syntax_error(client): def test_semantic_error(client): with pytest.raises(labelbox.exceptions.InvalidQueryError) as excinfo: client.execute("query {bbb {id}}", check_naming=False) - assert excinfo.value.message.startswith("Cannot query field \"bbb\"") + assert excinfo.value.message.startswith('Cannot query field "bbb"') def test_timeout_error(client, project): @@ -59,8 +59,9 @@ def test_timeout_error(client, project): def test_query_complexity_error(client): with pytest.raises(labelbox.exceptions.ValidationFailedError) as excinfo: - client.execute("{projects {datasets {dataRows {labels {id}}}}}", - check_naming=False) + client.execute( + "{projects {datasets {dataRows {labels {id}}}}}", check_naming=False + ) assert excinfo.value.message == "Query complexity limit exceeded" @@ -70,8 +71,9 @@ def test_resource_not_found_error(client): def test_network_error(client): - client = labelbox.client.Client(api_key=client.api_key, - endpoint="not_a_valid_URL") + client = labelbox.client.Client( + api_key=client.api_key, endpoint="not_a_valid_URL" + ) with pytest.raises(labelbox.exceptions.NetworkError) as excinfo: client.create_project(name="Project name") @@ -103,7 +105,6 @@ def test_invalid_attribute_error( @pytest.mark.skip("timeouts cause failure before rate limit") def test_api_limit_error(client): - def get(arg): try: return client.get_user() diff --git a/libs/labelbox/tests/integration/test_data_row_delete_metadata.py b/libs/labelbox/tests/integration/test_data_row_delete_metadata.py index 8674beb33..2df860181 100644 --- a/libs/labelbox/tests/integration/test_data_row_delete_metadata.py +++ b/libs/labelbox/tests/integration/test_data_row_delete_metadata.py @@ -5,7 +5,12 @@ from labelbox import DataRow, Dataset, Client, DataRowMetadataOntology from labelbox.exceptions import MalformedQueryException -from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadata, DataRowMetadataKind, DeleteDataRowMetadata +from labelbox.schema.data_row_metadata import ( + DataRowMetadataField, + DataRowMetadata, + DataRowMetadataKind, + DeleteDataRowMetadata, +) from labelbox.schema.identifiable import GlobalKey, UniqueId INVALID_SCHEMA_ID = "1" * 25 @@ -16,13 +21,13 @@ TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt" TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh" CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb" -CUSTOM_TEXT_SCHEMA_NAME = 'custom_text' +CUSTOM_TEXT_SCHEMA_NAME = "custom_text" FAKE_NUMBER_FIELD = { "id": FAKE_SCHEMA_ID, "name": "number", - "kind": 'CustomMetadataNumber', - "reserved": False + "kind": "CustomMetadataNumber", + "reserved": False, } @@ -42,13 +47,16 @@ def mdo(client: Client): @pytest.fixture def big_dataset(dataset: Dataset, image_url): - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image", - "global_key": str(uuid.uuid4()) - }, - ] * 5) + task = dataset.create_data_rows( + [ + { + "row_data": image_url, + "external_id": "my-image", + "global_key": str(uuid.uuid4()), + }, + ] + * 5 + ) task.wait_till_done() yield dataset @@ -62,11 +70,13 @@ def make_metadata(dr_id: str = None, gk: str = None) -> DataRowMetadata: global_key=gk, data_row_id=dr_id, fields=[ - DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID, - value=TEST_SPLIT_ID), + DataRowMetadataField( + schema_id=SPLIT_SCHEMA_ID, value=TEST_SPLIT_ID + ), DataRowMetadataField(schema_id=CAPTURE_DT_SCHEMA_ID, value=time), DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value=msg), - ]) + ], + ) 
return metadata @@ -74,15 +84,14 @@ def make_named_metadata(dr_id) -> DataRowMetadata: msg = "A message" time = datetime.now(timezone.utc) - metadata = DataRowMetadata(data_row_id=dr_id, - fields=[ - DataRowMetadataField(name='split', - value=TEST_SPLIT_ID), - DataRowMetadataField(name='captureDateTime', - value=time), - DataRowMetadataField( - name=CUSTOM_TEXT_SCHEMA_NAME, value=msg), - ]) + metadata = DataRowMetadata( + data_row_id=dr_id, + fields=[ + DataRowMetadataField(name="split", value=TEST_SPLIT_ID), + DataRowMetadataField(name="captureDateTime", value=time), + DataRowMetadataField(name=CUSTOM_TEXT_SCHEMA_NAME, value=msg), + ], + ) return metadata @@ -94,9 +103,11 @@ def test_bulk_delete_datarow_metadata(data_row, mdo): assert len(mdo.bulk_export([data_row.uid])[0].fields) upload_ids = [m.schema_id for m in metadata.fields[:-2]] mdo.bulk_delete( - [DeleteDataRowMetadata(data_row_id=data_row.uid, fields=upload_ids)]) + [DeleteDataRowMetadata(data_row_id=data_row.uid, fields=upload_ids)] + ) remaining_ids = set( - [f.schema_id for f in mdo.bulk_export([data_row.uid])[0].fields]) + [f.schema_id for f in mdo.bulk_export([data_row.uid])[0].fields] + ) assert not len(remaining_ids.intersection(set(upload_ids))) @@ -116,43 +127,55 @@ def data_row_id_as_str(data_row): @pytest.mark.parametrize( - 'data_row_for_delete', - ['data_row_id_as_str', 'data_row_unique_id', 'data_row_global_key']) -def test_bulk_delete_datarow_metadata(data_row_for_delete, data_row, mdo, - request): + "data_row_for_delete", + ["data_row_id_as_str", "data_row_unique_id", "data_row_global_key"], +) +def test_bulk_delete_datarow_metadata( + data_row_for_delete, data_row, mdo, request +): """test bulk deletes for all fields""" metadata = make_metadata(data_row.uid) mdo.bulk_upsert([metadata]) assert len(mdo.bulk_export([data_row.uid])[0].fields) upload_ids = [m.schema_id for m in metadata.fields[:-2]] - mdo.bulk_delete([ - DeleteDataRowMetadata( - data_row_id=request.getfixturevalue(data_row_for_delete), - fields=upload_ids) - ]) + mdo.bulk_delete( + [ + DeleteDataRowMetadata( + data_row_id=request.getfixturevalue(data_row_for_delete), + fields=upload_ids, + ) + ] + ) remaining_ids = set( - [f.schema_id for f in mdo.bulk_export([data_row.uid])[0].fields]) + [f.schema_id for f in mdo.bulk_export([data_row.uid])[0].fields] + ) assert not len(remaining_ids.intersection(set(upload_ids))) @pytest.mark.parametrize( - 'data_row_for_delete', - ['data_row_id_as_str', 'data_row_unique_id', 'data_row_global_key']) -def test_bulk_partial_delete_datarow_metadata(data_row_for_delete, data_row, - mdo, request): + "data_row_for_delete", + ["data_row_id_as_str", "data_row_unique_id", "data_row_global_key"], +) +def test_bulk_partial_delete_datarow_metadata( + data_row_for_delete, data_row, mdo, request +): """Delete a single from metadata""" n_fields = len(mdo.bulk_export([data_row.uid])[0].fields) metadata = make_metadata(data_row.uid) mdo.bulk_upsert([metadata]) - assert len(mdo.bulk_export( - [data_row.uid])[0].fields) == (n_fields + len(metadata.fields)) + assert len(mdo.bulk_export([data_row.uid])[0].fields) == ( + n_fields + len(metadata.fields) + ) - mdo.bulk_delete([ - DeleteDataRowMetadata( - data_row_id=request.getfixturevalue(data_row_for_delete), - fields=[TEXT_SCHEMA_ID]) - ]) + mdo.bulk_delete( + [ + DeleteDataRowMetadata( + data_row_id=request.getfixturevalue(data_row_for_delete), + fields=[TEXT_SCHEMA_ID], + ) + ] + ) fields = [f for f in mdo.bulk_export([data_row.uid])[0].fields] assert len(fields) == 
(len(metadata.fields) - 1) @@ -166,7 +189,9 @@ def data_row_unique_ids(big_dataset): deletes.append( DeleteDataRowMetadata( data_row_id=UniqueId(data_row_id), - fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID])) + fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID], + ) + ) return deletes @@ -179,7 +204,9 @@ def data_row_ids_as_str(big_dataset): deletes.append( DeleteDataRowMetadata( data_row_id=data_row_id, - fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID])) + fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID], + ) + ) return deletes @@ -192,26 +219,35 @@ def data_row_global_keys(big_dataset): deletes.append( DeleteDataRowMetadata( data_row_id=GlobalKey(data_row_id), - fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID])) + fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID], + ) + ) return deletes @pytest.mark.parametrize( - 'data_rows_for_delete', - ['data_row_ids_as_str', 'data_row_unique_ids', 'data_row_global_keys']) -def test_large_bulk_delete_datarow_metadata(data_rows_for_delete, big_dataset, - mdo, request): + "data_rows_for_delete", + ["data_row_ids_as_str", "data_row_unique_ids", "data_row_global_keys"], +) +def test_large_bulk_delete_datarow_metadata( + data_rows_for_delete, big_dataset, mdo, request +): metadata = [] data_row_ids = [dr.uid for dr in big_dataset.data_rows()] for data_row_id in data_row_ids: metadata.append( - DataRowMetadata(data_row_id=data_row_id, - fields=[ - DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID, - value=TEST_SPLIT_ID), - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, - value="test-message") - ])) + DataRowMetadata( + data_row_id=data_row_id, + fields=[ + DataRowMetadataField( + schema_id=SPLIT_SCHEMA_ID, value=TEST_SPLIT_ID + ), + DataRowMetadataField( + schema_id=TEXT_SCHEMA_ID, value="test-message" + ), + ], + ) + ) errors = mdo.bulk_upsert(metadata) assert len(errors) == 0 @@ -221,7 +257,7 @@ def test_large_bulk_delete_datarow_metadata(data_rows_for_delete, big_dataset, assert len(errors) == len(data_row_ids) for error in errors: assert error.fields == [CAPTURE_DT_SCHEMA_ID] - assert error.error == 'Schema did not exist' + assert error.error == "Schema did not exist" for data_row_id in data_row_ids: fields = [f for f in mdo.bulk_export([data_row_id])[0].fields] @@ -230,10 +266,15 @@ def test_large_bulk_delete_datarow_metadata(data_rows_for_delete, big_dataset, @pytest.mark.parametrize( - 'data_row_for_delete', - ['data_row_id_as_str', 'data_row_unique_id', 'data_row_global_key']) -def test_bulk_delete_datarow_enum_metadata(data_row_for_delete, - data_row: DataRow, mdo: DataRowMetadataOntology, request): + "data_row_for_delete", + ["data_row_id_as_str", "data_row_unique_id", "data_row_global_key"], +) +def test_bulk_delete_datarow_enum_metadata( + data_row_for_delete, + data_row: DataRow, + mdo: DataRowMetadataOntology, + request, +): """test bulk deletes for non non fields""" metadata = make_metadata(data_row.uid) metadata.fields = [ @@ -243,28 +284,39 @@ def test_bulk_delete_datarow_enum_metadata(data_row_for_delete, exported = mdo.bulk_export([data_row.uid])[0].fields assert len(exported) == len( - set([x.schema_id for x in metadata.fields] + - [x.schema_id for x in exported])) - - mdo.bulk_delete([ - DeleteDataRowMetadata( - data_row_id=request.getfixturevalue(data_row_for_delete), - fields=[SPLIT_SCHEMA_ID]) - ]) + set( + [x.schema_id for x in metadata.fields] + + [x.schema_id for x in exported] + ) + ) + + mdo.bulk_delete( + [ + DeleteDataRowMetadata( + data_row_id=request.getfixturevalue(data_row_for_delete), + fields=[SPLIT_SCHEMA_ID], + ) + ] + ) exported = 
mdo.bulk_export([data_row.uid])[0].fields assert len(exported) == 0 @pytest.mark.parametrize( - 'data_row_for_delete', - ['data_row_id_as_str', 'data_row_unique_id', 'data_row_global_key']) -def test_delete_non_existent_schema_id(data_row_for_delete, data_row, mdo, - request): - res = mdo.bulk_delete([ - DeleteDataRowMetadata( - data_row_id=request.getfixturevalue(data_row_for_delete), - fields=[SPLIT_SCHEMA_ID]) - ]) + "data_row_for_delete", + ["data_row_id_as_str", "data_row_unique_id", "data_row_global_key"], +) +def test_delete_non_existent_schema_id( + data_row_for_delete, data_row, mdo, request +): + res = mdo.bulk_delete( + [ + DeleteDataRowMetadata( + data_row_id=request.getfixturevalue(data_row_for_delete), + fields=[SPLIT_SCHEMA_ID], + ) + ] + ) assert len(res) == 1 assert res[0].fields == [SPLIT_SCHEMA_ID] - assert res[0].error == 'Schema did not exist' + assert res[0].error == "Schema did not exist" diff --git a/libs/labelbox/tests/integration/test_data_rows.py b/libs/labelbox/tests/integration/test_data_rows.py index 454d55b87..7f69c2995 100644 --- a/libs/labelbox/tests/integration/test_data_rows.py +++ b/libs/labelbox/tests/integration/test_data_rows.py @@ -10,16 +10,26 @@ from labelbox.schema.media_type import MediaType from labelbox import DataRow, AssetAttachment -from labelbox.exceptions import MalformedQueryException, ResourceCreationError, InvalidQueryError +from labelbox.exceptions import ( + MalformedQueryException, + ResourceCreationError, + InvalidQueryError, +) from labelbox.schema.task import Task, DataUpsertTask -from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadataKind +from labelbox.schema.data_row_metadata import ( + DataRowMetadataField, + DataRowMetadataKind, +) SPLIT_SCHEMA_ID = "cko8sbczn0002h2dkdaxb5kal" TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt" TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh" CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb" EXPECTED_METADATA_SCHEMA_IDS = [ - SPLIT_SCHEMA_ID, TEST_SPLIT_ID, TEXT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID + SPLIT_SCHEMA_ID, + TEST_SPLIT_ID, + TEXT_SCHEMA_ID, + CAPTURE_DT_SCHEMA_ID, ].sort() CUSTOM_TEXT_SCHEMA_NAME = "custom_text" @@ -40,20 +50,19 @@ def mdo(client): @pytest.fixture def conversational_content(): return { - 'row_data': { - "messages": [{ - "messageId": "message-0", - "timestampUsec": 1530718491, - "content": "I love iphone! i just bought new iphone! 🥰 📲", - "user": { - "userId": "Bot 002", - "name": "Bot" - }, - "align": "left", - "canLabel": False - }], + "row_data": { + "messages": [ + { + "messageId": "message-0", + "timestampUsec": 1530718491, + "content": "I love iphone! i just bought new iphone! 
🥰 📲", + "user": {"userId": "Bot 002", "name": "Bot"}, + "align": "left", + "canLabel": False, + } + ], "version": 1, - "type": "application/vnd.labelbox.conversational" + "type": "application/vnd.labelbox.conversational", } } @@ -62,27 +71,24 @@ def conversational_content(): def tile_content(): return { "row_data": { - "tileLayerUrl": - "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png", - "bounds": [[19.405662413477728, -99.21052827588443], - [19.400498983095076, -99.20534818927473]], - "minZoom": - 12, - "maxZoom": - 20, - "epsg": - "EPSG4326", - "alternativeLayers": [{ - "tileLayerUrl": - "https://api.mapbox.com/styles/v1/mapbox/satellite-streets-v11/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw", - "name": - "Satellite" - }, { - "tileLayerUrl": - "https://api.mapbox.com/styles/v1/mapbox/navigation-guidance-night-v4/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw", - "name": - "Guidance" - }] + "tileLayerUrl": "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png", + "bounds": [ + [19.405662413477728, -99.21052827588443], + [19.400498983095076, -99.20534818927473], + ], + "minZoom": 12, + "maxZoom": 20, + "epsg": "EPSG4326", + "alternativeLayers": [ + { + "tileLayerUrl": "https://api.mapbox.com/styles/v1/mapbox/satellite-streets-v11/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw", + "name": "Satellite", + }, + { + "tileLayerUrl": "https://api.mapbox.com/styles/v1/mapbox/navigation-guidance-night-v4/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw", + "name": "Guidance", + }, + ], } } @@ -103,16 +109,11 @@ def make_metadata_fields_dict(): msg = "A message" time = datetime.utcnow() - fields = [{ - "schema_id": SPLIT_SCHEMA_ID, - "value": TEST_SPLIT_ID - }, { - "schema_id": CAPTURE_DT_SCHEMA_ID, - "value": time - }, { - "schema_id": TEXT_SCHEMA_ID, - "value": msg - }] + fields = [ + {"schema_id": SPLIT_SCHEMA_ID, "value": TEST_SPLIT_ID}, + {"schema_id": CAPTURE_DT_SCHEMA_ID, "value": time}, + {"schema_id": TEXT_SCHEMA_ID, "value": msg}, + ] return fields @@ -133,9 +134,9 @@ def test_create_invalid_aws_data_row(dataset, client): assert "s3" in exc.value.message with pytest.raises(InvalidQueryError) as exc: - dataset.create_data_rows([{ - "row_data": "s3://labelbox-public-data/invalid" - }]) + dataset.create_data_rows( + [{"row_data": "s3://labelbox-public-data/invalid"}] + ) assert "s3" in exc.value.message @@ -176,15 +177,12 @@ def test_data_row_bulk_creation(dataset, rand_gen, image_url): try: payload = [ - { - DataRow.row_data: image_url - }, - { - "row_data": image_url - }, + {DataRow.row_data: image_url}, + {"row_data": image_url}, ] - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', - new=300): # To make 2 chunks + with patch( + "labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES", new=300 + ): # To make 2 chunks # Test creation using URL task = dataset.create_data_rows(payload, file_upload_thread_count=2) task.wait_till_done() @@ -225,10 +223,12 @@ def local_image_file(image_url) -> NamedTemporaryFile: def test_data_row_bulk_creation_from_file(dataset, local_image_file, image_url): - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', - new=500): # Force chunking + with patch( + 
"labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES", new=500 + ): # Force chunking task = dataset.create_data_rows( - [local_image_file.name, local_image_file.name]) + [local_image_file.name, local_image_file.name] + ) task.wait_till_done() assert task.status == "COMPLETE" assert len(task.result) == 2 @@ -239,16 +239,17 @@ def test_data_row_bulk_creation_from_file(dataset, local_image_file, image_url): def test_data_row_bulk_creation_from_row_data_file_external_id( - dataset, local_image_file, image_url): - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', - new=500): # Force chunking - task = dataset.create_data_rows([{ - "row_data": local_image_file.name, - 'external_id': 'some_name' - }, { - "row_data": image_url, - 'external_id': 'some_name2' - }]) + dataset, local_image_file, image_url +): + with patch( + "labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES", new=500 + ): # Force chunking + task = dataset.create_data_rows( + [ + {"row_data": local_image_file.name, "external_id": "some_name"}, + {"row_data": image_url, "external_id": "some_name2"}, + ] + ) task.wait_till_done() assert task.status == "COMPLETE" assert len(task.result) == 2 @@ -259,15 +260,18 @@ def test_data_row_bulk_creation_from_row_data_file_external_id( assert image_url in row_data -def test_data_row_bulk_creation_from_row_data_file(dataset, rand_gen, - local_image_file, image_url): - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', - new=500): # Force chunking - task = dataset.create_data_rows([{ - "row_data": local_image_file.name - }, { - "row_data": local_image_file.name - }]) +def test_data_row_bulk_creation_from_row_data_file( + dataset, rand_gen, local_image_file, image_url +): + with patch( + "labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES", new=500 + ): # Force chunking + task = dataset.create_data_rows( + [ + {"row_data": local_image_file.name}, + {"row_data": local_image_file.name}, + ] + ) task.wait_till_done() assert task.status == "COMPLETE" assert len(task.result) == 2 @@ -285,9 +289,9 @@ def test_data_row_large_bulk_creation(dataset, image_url): with NamedTemporaryFile() as fp: fp.write("Test data".encode()) fp.flush() - task = dataset.create_data_rows([{ - DataRow.row_data: image_url - }] * n_urls + [fp.name] * n_local) + task = dataset.create_data_rows( + [{DataRow.row_data: image_url}] * n_urls + [fp.name] * n_local + ) task.wait_till_done() assert task.status == "COMPLETE" assert len(list(dataset.data_rows())) == n_local + n_urls @@ -302,8 +306,10 @@ def test_data_row_single_creation(dataset, rand_gen, image_url): assert data_row.dataset() == dataset assert data_row.created_by() == client.get_user() assert data_row.organization() == client.get_organization() - assert requests.get(image_url).content == \ - requests.get(data_row.row_data).content + assert ( + requests.get(image_url).content + == requests.get(data_row.row_data).content + ) assert data_row.media_attributes is not None assert data_row.global_key is None @@ -325,8 +331,10 @@ def test_create_data_row_with_dict(dataset, image_url): assert data_row.dataset() == dataset assert data_row.created_by() == client.get_user() assert data_row.organization() == client.get_organization() - assert requests.get(image_url).content == \ - requests.get(data_row.row_data).content + assert ( + requests.get(image_url).content + == requests.get(data_row.row_data).content + ) assert data_row.media_attributes is not None @@ -339,8 +347,10 @@ def test_create_data_row_with_dict_containing_field(dataset, image_url): assert data_row.dataset() 
== dataset
     assert data_row.created_by() == client.get_user()
     assert data_row.organization() == client.get_organization()
-    assert requests.get(image_url).content == \
-        requests.get(data_row.row_data).content
+    assert (
+        requests.get(image_url).content
+        == requests.get(data_row.row_data).content
+    )
     assert data_row.media_attributes is not None
@@ -353,8 +363,10 @@ def test_create_data_row_with_dict_unpacked(dataset, image_url):
     assert data_row.dataset() == dataset
     assert data_row.created_by() == client.get_user()
     assert data_row.organization() == client.get_organization()
-    assert requests.get(image_url).content == \
-        requests.get(data_row.row_data).content
+    assert (
+        requests.get(image_url).content
+        == requests.get(data_row.row_data).content
+    )
     assert data_row.media_attributes is not None
@@ -367,22 +379,26 @@ def test_create_data_row_with_metadata(mdo, dataset, image_url):
     client = dataset.client
     assert len(list(dataset.data_rows())) == 0
-    data_row = dataset.create_data_row(row_data=image_url,
-                                       metadata_fields=make_metadata_fields())
+    data_row = dataset.create_data_row(
+        row_data=image_url, metadata_fields=make_metadata_fields()
+    )
     assert len(list(dataset.data_rows())) == 1
     assert data_row.dataset() == dataset
     assert data_row.created_by() == client.get_user()
     assert data_row.organization() == client.get_organization()
-    assert requests.get(image_url).content == \
-        requests.get(data_row.row_data).content
+    assert (
+        requests.get(image_url).content
+        == requests.get(data_row.row_data).content
+    )
     assert data_row.media_attributes is not None
     metadata_fields = data_row.metadata_fields
     metadata = data_row.metadata
     assert len(metadata_fields) == 3
     assert len(metadata) == 3
-    assert [m["schemaId"] for m in metadata_fields
-           ].sort() == EXPECTED_METADATA_SCHEMA_IDS
+    assert [
+        m["schemaId"] for m in metadata_fields
+    ].sort() == EXPECTED_METADATA_SCHEMA_IDS
     for m in metadata:
         assert mdo._parse_upsert(m)
@@ -392,21 +408,25 @@ def test_create_data_row_with_metadata_dict(mdo, dataset, image_url):
     assert len(list(dataset.data_rows())) == 0
     data_row = dataset.create_data_row(
-        row_data=image_url, metadata_fields=make_metadata_fields_dict())
+        row_data=image_url, metadata_fields=make_metadata_fields_dict()
+    )
     assert len(list(dataset.data_rows())) == 1
     assert data_row.dataset() == dataset
     assert data_row.created_by() == client.get_user()
     assert data_row.organization() == client.get_organization()
-    assert requests.get(image_url).content == \
-        requests.get(data_row.row_data).content
+    assert (
+        requests.get(image_url).content
+        == requests.get(data_row.row_data).content
+    )
    assert data_row.media_attributes is not None
     metadata_fields = data_row.metadata_fields
     metadata = data_row.metadata
     assert len(metadata_fields) == 3
     assert len(metadata) == 3
-    assert [m["schemaId"] for m in metadata_fields
-           ].sort() == EXPECTED_METADATA_SCHEMA_IDS
+    assert [
+        m["schemaId"] for m in metadata_fields
+    ].sort() == EXPECTED_METADATA_SCHEMA_IDS
     for m in metadata:
         assert mdo._parse_upsert(m)
@@ -415,7 +435,8 @@ def test_create_data_row_with_invalid_metadata(dataset, image_url):
     fields = make_metadata_fields()
     # make the payload invalid by providing the same schema id more than once
     fields.append(
-        DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value='some msg'))
+        DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value="some msg")
+    )
     with pytest.raises(ResourceCreationError):
         dataset.create_data_row(row_data=image_url, metadata_fields=fields)
@@ -425,28 +446,30 @@ def test_create_data_rows_with_metadata(mdo, dataset, image_url):
     client = dataset.client
     assert len(list(dataset.data_rows())) == 0
-    task = dataset.create_data_rows([
-        {
-            DataRow.row_data: image_url,
-            DataRow.external_id: "row1",
-            DataRow.metadata_fields: make_metadata_fields()
-        },
-        {
-            DataRow.row_data: image_url,
-            DataRow.external_id: "row2",
-            "metadata_fields": make_metadata_fields()
-        },
-        {
-            DataRow.row_data: image_url,
-            DataRow.external_id: "row3",
-            DataRow.metadata_fields: make_metadata_fields_dict()
-        },
-        {
-            DataRow.row_data: image_url,
-            DataRow.external_id: "row4",
-            "metadata_fields": make_metadata_fields_dict()
-        },
-    ])
+    task = dataset.create_data_rows(
+        [
+            {
+                DataRow.row_data: image_url,
+                DataRow.external_id: "row1",
+                DataRow.metadata_fields: make_metadata_fields(),
+            },
+            {
+                DataRow.row_data: image_url,
+                DataRow.external_id: "row2",
+                "metadata_fields": make_metadata_fields(),
+            },
+            {
+                DataRow.row_data: image_url,
+                DataRow.external_id: "row3",
+                DataRow.metadata_fields: make_metadata_fields_dict(),
+            },
+            {
+                DataRow.row_data: image_url,
+                DataRow.external_id: "row4",
+                "metadata_fields": make_metadata_fields_dict(),
+            },
+        ]
+    )
     task.wait_till_done()
     assert len(list(dataset.data_rows())) == 4
@@ -455,63 +478,60 @@ def test_create_data_rows_with_metadata(mdo, dataset, image_url):
         assert row.dataset() == dataset
         assert row.created_by() == client.get_user()
         assert row.organization() == client.get_organization()
-        assert requests.get(image_url).content == \
-            requests.get(row.row_data).content
+        assert (
+            requests.get(image_url).content
+            == requests.get(row.row_data).content
+        )
         assert row.media_attributes is not None
         metadata_fields = row.metadata_fields
         metadata = row.metadata
         assert len(metadata_fields) == 3
         assert len(metadata) == 3
-        assert [m["schemaId"] for m in metadata_fields
-               ].sort() == EXPECTED_METADATA_SCHEMA_IDS
+        assert [
+            m["schemaId"] for m in metadata_fields
+        ].sort() == EXPECTED_METADATA_SCHEMA_IDS
         for m in metadata:
             assert mdo._parse_upsert(m)
-@pytest.mark.parametrize("test_function,metadata_obj_type",
-                         [("create_data_rows", "class"),
-                          ("create_data_rows", "dict"),
-                          ("create_data_rows_sync", "class"),
-                          ("create_data_rows_sync", "dict"),
-                          ("create_data_row", "class"),
-                          ("create_data_row", "dict")])
+@pytest.mark.parametrize(
+    "test_function,metadata_obj_type",
+    [
+        ("create_data_rows", "class"),
+        ("create_data_rows", "dict"),
+        ("create_data_rows_sync", "class"),
+        ("create_data_rows_sync", "dict"),
+        ("create_data_row", "class"),
+        ("create_data_row", "dict"),
+    ],
+)
 def test_create_data_rows_with_named_metadata_field_class(
-        test_function, metadata_obj_type, mdo, dataset, image_url):
-
+    test_function, metadata_obj_type, mdo, dataset, image_url
+):
     row_with_metadata_field = {
-        DataRow.row_data:
-            image_url,
-        DataRow.external_id:
-            "row1",
+        DataRow.row_data: image_url,
+        DataRow.external_id: "row1",
         DataRow.metadata_fields: [
-            DataRowMetadataField(name='split', value='test'),
-            DataRowMetadataField(name=CUSTOM_TEXT_SCHEMA_NAME, value='hello')
-        ]
+            DataRowMetadataField(name="split", value="test"),
+            DataRowMetadataField(name=CUSTOM_TEXT_SCHEMA_NAME, value="hello"),
+        ],
     }
     row_with_metadata_dict = {
-        DataRow.row_data:
-            image_url,
-        DataRow.external_id:
-            "row2",
+        DataRow.row_data: image_url,
+        DataRow.external_id: "row2",
         "metadata_fields": [
-            {
-                'name': 'split',
-                'value': 'test'
-            },
-            {
-                'name': CUSTOM_TEXT_SCHEMA_NAME,
-                'value': 'hello'
-            },
-        ]
+            {"name": "split", "value": "test"},
+            {"name":
CUSTOM_TEXT_SCHEMA_NAME, "value": "hello"}, + ], } assert len(list(dataset.data_rows())) == 0 METADATA_FIELDS = { "class": row_with_metadata_field, - "dict": row_with_metadata_dict + "dict": row_with_metadata_dict, } def create_data_row(data_rows): @@ -520,7 +540,7 @@ def create_data_row(data_rows): CREATION_FUNCTION = { "create_data_rows": dataset.create_data_rows, "create_data_rows_sync": dataset.create_data_rows_sync, - "create_data_row": create_data_row + "create_data_row": create_data_row, } data_rows = [METADATA_FIELDS[metadata_obj_type]] function_to_test = CREATION_FUNCTION[test_function] @@ -536,30 +556,33 @@ def create_data_row(data_rows): metadata = created_rows[0].metadata assert metadata[0].schema_id == SPLIT_SCHEMA_ID - assert metadata[0].name == 'test' - assert metadata[0].value == mdo.reserved_by_name['split']['test'].uid + assert metadata[0].name == "test" + assert metadata[0].value == mdo.reserved_by_name["split"]["test"].uid assert metadata[1].name == CUSTOM_TEXT_SCHEMA_NAME - assert metadata[1].value == 'hello' - assert metadata[1].schema_id == mdo.custom_by_name[ - CUSTOM_TEXT_SCHEMA_NAME].uid + assert metadata[1].value == "hello" + assert ( + metadata[1].schema_id == mdo.custom_by_name[CUSTOM_TEXT_SCHEMA_NAME].uid + ) def test_create_data_rows_with_invalid_metadata(dataset, image_url): fields = make_metadata_fields() # make the payload invalid by providing the same schema id more than once fields.append( - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value="some msg")) + DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value="some msg") + ) - task = dataset.create_data_rows([{ - DataRow.row_data: image_url, - DataRow.metadata_fields: fields - }]) + task = dataset.create_data_rows( + [{DataRow.row_data: image_url, DataRow.metadata_fields: fields}] + ) task.wait_till_done(timeout_seconds=60) assert task.status == "COMPLETE" assert len(task.failed_data_rows) == 1 - assert f"A schemaId can only be specified once per DataRow : [{TEXT_SCHEMA_ID}]" in task.failed_data_rows[ - 0]["message"] + assert ( + f"A schemaId can only be specified once per DataRow : [{TEXT_SCHEMA_ID}]" + in task.failed_data_rows[0]["message"] + ) def test_create_data_rows_with_metadata_missing_value(dataset, image_url): @@ -567,13 +590,15 @@ def test_create_data_rows_with_metadata_missing_value(dataset, image_url): fields.append({"schemaId": "some schema id"}) with pytest.raises(ValueError) as exc: - dataset.create_data_rows([ - { - DataRow.row_data: image_url, - DataRow.external_id: "row1", - DataRow.metadata_fields: fields - }, - ]) + dataset.create_data_rows( + [ + { + DataRow.row_data: image_url, + DataRow.external_id: "row1", + DataRow.metadata_fields: fields, + }, + ] + ) def test_create_data_rows_with_metadata_missing_schema_id(dataset, image_url): @@ -581,13 +606,15 @@ def test_create_data_rows_with_metadata_missing_schema_id(dataset, image_url): fields.append({"value": "some value"}) with pytest.raises(ValueError) as exc: - dataset.create_data_rows([ - { - DataRow.row_data: image_url, - DataRow.external_id: "row1", - DataRow.metadata_fields: fields - }, - ]) + dataset.create_data_rows( + [ + { + DataRow.row_data: image_url, + DataRow.external_id: "row1", + DataRow.metadata_fields: fields, + }, + ] + ) def test_create_data_rows_with_metadata_wrong_type(dataset, image_url): @@ -595,20 +622,24 @@ def test_create_data_rows_with_metadata_wrong_type(dataset, image_url): fields.append("Neither DataRowMetadataField or dict") with pytest.raises(ValueError) as exc: - task = dataset.create_data_rows([ - { 
- DataRow.row_data: image_url, - DataRow.external_id: "row1", - DataRow.metadata_fields: fields - }, - ]) + task = dataset.create_data_rows( + [ + { + DataRow.row_data: image_url, + DataRow.external_id: "row1", + DataRow.metadata_fields: fields, + }, + ] + ) def test_data_row_update_missing_or_empty_required_fields( - dataset, rand_gen, image_url): + dataset, rand_gen, image_url +): external_id = rand_gen(str) - data_row = dataset.create_data_row(row_data=image_url, - external_id=external_id) + data_row = dataset.create_data_row( + row_data=image_url, external_id=external_id + ) with pytest.raises(ValueError): data_row.update(row_data="") with pytest.raises(ValueError): @@ -621,11 +652,13 @@ def test_data_row_update_missing_or_empty_required_fields( data_row.update() -def test_data_row_update(client, dataset, rand_gen, image_url, - wait_for_data_row_processing): +def test_data_row_update( + client, dataset, rand_gen, image_url, wait_for_data_row_processing +): external_id = rand_gen(str) - data_row = dataset.create_data_row(row_data=image_url, - external_id=external_id) + data_row = dataset.create_data_row( + row_data=image_url, external_id=external_id + ) assert data_row.external_id == external_id external_id_2 = rand_gen(str) @@ -643,25 +676,23 @@ def test_data_row_update(client, dataset, rand_gen, image_url, # tileLayer becomes a media attribute pdf_url = "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" tileLayerUrl = "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483-lb-textlayer.json" - data_row.update(row_data={'pdfUrl': pdf_url, "tileLayerUrl": tileLayerUrl}) - custom_check = lambda data_row: data_row.row_data and 'pdfUrl' not in data_row.row_data - data_row = wait_for_data_row_processing(client, - data_row, - custom_check=custom_check) + data_row.update(row_data={"pdfUrl": pdf_url, "tileLayerUrl": tileLayerUrl}) + custom_check = ( + lambda data_row: data_row.row_data and "pdfUrl" not in data_row.row_data + ) + data_row = wait_for_data_row_processing( + client, data_row, custom_check=custom_check + ) assert data_row.row_data == pdf_url def test_data_row_filtering_sorting(dataset, image_url): - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url, - DataRow.external_id: "row1" - }, - { - DataRow.row_data: image_url, - DataRow.external_id: "row2" - }, - ]) + task = dataset.create_data_rows( + [ + {DataRow.row_data: image_url, DataRow.external_id: "row1"}, + {DataRow.row_data: image_url, DataRow.external_id: "row2"}, + ] + ) task.wait_till_done() # Test filtering @@ -681,10 +712,12 @@ def test_data_row_filtering_sorting(dataset, image_url): @pytest.fixture def create_datarows_for_data_row_deletion(dataset, image_url): - task = dataset.create_data_rows([{ - DataRow.row_data: image_url, - DataRow.external_id: str(i) - } for i in range(10)]) + task = dataset.create_data_rows( + [ + {DataRow.row_data: image_url, DataRow.external_id: str(i)} + for i in range(10) + ] + ) task.wait_till_done() data_rows = list(dataset.data_rows()) @@ -716,34 +749,39 @@ def test_data_row_deletion(dataset, create_datarows_for_data_row_deletion): def test_data_row_iteration(dataset, image_url) -> None: - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url - }, - { - "row_data": image_url - }, - ]) + task = dataset.create_data_rows( + [ + {DataRow.row_data: image_url}, + {"row_data": image_url}, + ] + ) task.wait_till_done() assert next(dataset.data_rows()) def 
test_data_row_attachments(dataset, image_url): - attachments = [("IMAGE", image_url, "attachment image"), - ("RAW_TEXT", "test-text", None), - ("IMAGE_OVERLAY", image_url, "Overlay"), - ("HTML", image_url, None)] - task = dataset.create_data_rows([{ - "row_data": - image_url, - "external_id": - "test-id", - "attachments": [{ - "type": attachment_type, - "value": attachment_value, - "name": attachment_name - }] - } for attachment_type, attachment_value, attachment_name in attachments]) + attachments = [ + ("IMAGE", image_url, "attachment image"), + ("RAW_TEXT", "test-text", None), + ("IMAGE_OVERLAY", image_url, "Overlay"), + ("HTML", image_url, None), + ] + task = dataset.create_data_rows( + [ + { + "row_data": image_url, + "external_id": "test-id", + "attachments": [ + { + "type": attachment_type, + "value": attachment_value, + "name": attachment_name, + } + ], + } + for attachment_type, attachment_value, attachment_name in attachments + ] + ) task.wait_till_done() assert task.status == "COMPLETE" @@ -754,33 +792,42 @@ def test_data_row_attachments(dataset, image_url): assert data_row.external_id == "test-id" with pytest.raises(ValueError) as exc: - task = dataset.create_data_rows([{ - "row_data": image_url, - "external_id": "test-id", - "attachments": [{ - "type": "INVALID", - "value": "123" - }] - }]) + task = dataset.create_data_rows( + [ + { + "row_data": image_url, + "external_id": "test-id", + "attachments": [{"type": "INVALID", "value": "123"}], + } + ] + ) def test_create_data_rows_sync_attachments(dataset, image_url): - attachments = [("IMAGE", image_url, "image URL"), - ("RAW_TEXT", "test-text", None), - ("IMAGE_OVERLAY", image_url, "Overlay"), - ("HTML", image_url, None)] + attachments = [ + ("IMAGE", image_url, "image URL"), + ("RAW_TEXT", "test-text", None), + ("IMAGE_OVERLAY", image_url, "Overlay"), + ("HTML", image_url, None), + ] attachments_per_data_row = 3 - dataset.create_data_rows_sync([{ - "row_data": - image_url, - "external_id": - "test-id", - "attachments": [{ - "type": attachment_type, - "value": attachment_value, - "name": attachment_name - } for _ in range(attachments_per_data_row)] - } for attachment_type, attachment_value, attachment_name in attachments]) + dataset.create_data_rows_sync( + [ + { + "row_data": image_url, + "external_id": "test-id", + "attachments": [ + { + "type": attachment_type, + "value": attachment_value, + "name": attachment_name, + } + for _ in range(attachments_per_data_row) + ], + } + for attachment_type, attachment_value, attachment_name in attachments + ] + ) data_rows = list(dataset.data_rows()) assert len(data_rows) == len(attachments) for data_row in data_rows: @@ -793,15 +840,16 @@ def test_create_data_rows_sync_mixed_upload(dataset, image_url): with NamedTemporaryFile() as fp: fp.write("Test data".encode()) fp.flush() - dataset.create_data_rows_sync([{ - DataRow.row_data: image_url - }] * n_urls + [fp.name] * n_local) + dataset.create_data_rows_sync( + [{DataRow.row_data: image_url}] * n_urls + [fp.name] * n_local + ) assert len(list(dataset.data_rows())) == n_local + n_urls def test_create_data_row_attachment(data_row): - att = data_row.create_attachment("IMAGE", "https://example.com/image.jpg", - "name") + att = data_row.create_attachment( + "IMAGE", "https://example.com/image.jpg", "name" + ) assert att.attachment_type == "IMAGE" assert att.attachment_value == "https://example.com/image.jpg" assert att.attachment_name == "name" @@ -823,21 +871,30 @@ def test_delete_data_row_attachment(data_row, image_url): attachments = [] # 
Anonymous attachment - to_attach = [("IMAGE", image_url), ("RAW_TEXT", "test-text"), - ("IMAGE_OVERLAY", image_url), ("HTML", image_url)] + to_attach = [ + ("IMAGE", image_url), + ("RAW_TEXT", "test-text"), + ("IMAGE_OVERLAY", image_url), + ("HTML", image_url), + ] for attachment_type, attachment_value in to_attach: attachments.append( - data_row.create_attachment(attachment_type, attachment_value)) + data_row.create_attachment(attachment_type, attachment_value) + ) # Attachment with a name - to_attach = [("IMAGE", image_url, "Att. Image"), - ("RAW_TEXT", "test-text", "Att. Text"), - ("IMAGE_OVERLAY", image_url, "Image Overlay"), - ("HTML", image_url, "Att. HTML")] + to_attach = [ + ("IMAGE", image_url, "Att. Image"), + ("RAW_TEXT", "test-text", "Att. Text"), + ("IMAGE_OVERLAY", image_url, "Image Overlay"), + ("HTML", image_url, "Att. HTML"), + ] for attachment_type, attachment_value, attachment_name in to_attach: attachments.append( - data_row.create_attachment(attachment_type, attachment_value, - attachment_name)) + data_row.create_attachment( + attachment_type, attachment_value, attachment_name + ) + ) for attachment in attachments: attachment.delete() @@ -847,7 +904,8 @@ def test_delete_data_row_attachment(data_row, image_url): def test_update_data_row_attachment(data_row, image_url): attachment: AssetAttachment = data_row.create_attachment( - "RAW_TEXT", "value", "name") + "RAW_TEXT", "value", "name" + ) assert attachment is not None attachment.update(name="updated name", type="IMAGE", value=image_url) assert attachment.attachment_name == "updated name" @@ -857,7 +915,8 @@ def test_update_data_row_attachment(data_row, image_url): def test_update_data_row_attachment_invalid_type(data_row): attachment: AssetAttachment = data_row.create_attachment( - "RAW_TEXT", "value", "name") + "RAW_TEXT", "value", "name" + ) assert attachment is not None with pytest.raises(ValueError): attachment.update(name="updated name", type="INVALID", value="value") @@ -865,7 +924,8 @@ def test_update_data_row_attachment_invalid_type(data_row): def test_update_data_row_attachment_invalid_value(data_row): attachment: AssetAttachment = data_row.create_attachment( - "RAW_TEXT", "value", "name") + "RAW_TEXT", "value", "name" + ) assert attachment is not None with pytest.raises(ValueError): attachment.update(name="updated name", type="IMAGE", value="") @@ -873,7 +933,8 @@ def test_update_data_row_attachment_invalid_value(data_row): def test_does_not_update_not_provided_attachment_fields(data_row): attachment: AssetAttachment = data_row.create_attachment( - "RAW_TEXT", "value", "name") + "RAW_TEXT", "value", "name" + ) assert attachment is not None attachment.update(value=None, name="name") assert attachment.attachment_value == "value" @@ -884,27 +945,33 @@ def test_does_not_update_not_provided_attachment_fields(data_row): def test_create_data_rows_result(client, dataset, image_url): - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url, - DataRow.external_id: "row1", - }, - { - DataRow.row_data: image_url, - DataRow.external_id: "row1", - }, - ]) + task = dataset.create_data_rows( + [ + { + DataRow.row_data: image_url, + DataRow.external_id: "row1", + }, + { + DataRow.row_data: image_url, + DataRow.external_id: "row1", + }, + ] + ) task.wait_till_done() assert task.errors is None for result in task.result: - client.get_data_row(result['id']) + client.get_data_row(result["id"]) def test_create_data_rows_local_file(dataset, sample_image): - task = dataset.create_data_rows([{ - DataRow.row_data: 
sample_image, - DataRow.metadata_fields: make_metadata_fields() - }]) + task = dataset.create_data_rows( + [ + { + DataRow.row_data: sample_image, + DataRow.metadata_fields: make_metadata_fields(), + } + ] + ) task.wait_till_done() assert task.status == "COMPLETE" data_row = list(dataset.data_rows())[0] @@ -914,10 +981,9 @@ def test_create_data_rows_local_file(dataset, sample_image): def test_data_row_with_global_key(dataset, sample_image): global_key = str(uuid.uuid4()) - row = dataset.create_data_row({ - DataRow.row_data: sample_image, - DataRow.global_key: global_key - }) + row = dataset.create_data_row( + {DataRow.row_data: sample_image, DataRow.global_key: global_key} + ) assert row.global_key == global_key @@ -927,36 +993,32 @@ def test_data_row_bulk_creation_with_unique_global_keys(dataset, sample_image): global_key_2 = str(uuid.uuid4()) global_key_3 = str(uuid.uuid4()) - task = dataset.create_data_rows([ - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }, - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_2 - }, - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_3 - }, - ]) + task = dataset.create_data_rows( + [ + {DataRow.row_data: sample_image, DataRow.global_key: global_key_1}, + {DataRow.row_data: sample_image, DataRow.global_key: global_key_2}, + {DataRow.row_data: sample_image, DataRow.global_key: global_key_3}, + ] + ) task.wait_till_done() - assert {row.global_key for row in dataset.data_rows() - } == {global_key_1, global_key_2, global_key_3} + assert {row.global_key for row in dataset.data_rows()} == { + global_key_1, + global_key_2, + global_key_3, + } -def test_data_row_bulk_creation_with_same_global_keys(dataset, sample_image, - snapshot): +def test_data_row_bulk_creation_with_same_global_keys( + dataset, sample_image, snapshot +): global_key_1 = str(uuid.uuid4()) - task = dataset.create_data_rows([{ - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }, { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }]) + task = dataset.create_data_rows( + [ + {DataRow.row_data: sample_image, DataRow.global_key: global_key_1}, + {DataRow.row_data: sample_image, DataRow.global_key: global_key_1}, + ] + ) task.wait_till_done() @@ -965,12 +1027,16 @@ def test_data_row_bulk_creation_with_same_global_keys(dataset, sample_image, assert len(task.failed_data_rows) == 1 assert type(task.created_data_rows) is list assert len(task.created_data_rows) == 1 - assert task.failed_data_rows[0][ - 'message'] == f"Duplicate global key: '{global_key_1}'" - assert task.failed_data_rows[0]['failedDataRows'][0][ - 'externalId'] == sample_image - assert task.created_data_rows[0]['external_id'] == sample_image - assert task.created_data_rows[0]['global_key'] == global_key_1 + assert ( + task.failed_data_rows[0]["message"] + == f"Duplicate global key: '{global_key_1}'" + ) + assert ( + task.failed_data_rows[0]["failedDataRows"][0]["externalId"] + == sample_image + ) + assert task.created_data_rows[0]["external_id"] == sample_image + assert task.created_data_rows[0]["global_key"] == global_key_1 assert len(task.errors) == 1 assert task.has_errors() is True @@ -980,11 +1046,12 @@ def test_data_row_bulk_creation_with_same_global_keys(dataset, sample_image, def test_data_row_delete_and_create_with_same_global_key( - client, dataset, sample_image): + client, dataset, sample_image +): global_key_1 = str(uuid.uuid4()) data_row_payload = { DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 + 
DataRow.global_key: global_key_1, } # should successfully insert new datarow @@ -992,9 +1059,9 @@ def test_data_row_delete_and_create_with_same_global_key( task.wait_till_done() assert task.status == "COMPLETE" - assert task.result[0]['global_key'] == global_key_1 + assert task.result[0]["global_key"] == global_key_1 - new_data_row_id = task.result[0]['id'] + new_data_row_id = task.result[0]["id"] # same payload should fail due to duplicated global key task = dataset.create_data_rows([data_row_payload]) @@ -1002,8 +1069,10 @@ def test_data_row_delete_and_create_with_same_global_key( assert task.status == "COMPLETE" assert len(task.failed_data_rows) == 1 - assert task.failed_data_rows[0][ - 'message'] == f"Duplicate global key: '{global_key_1}'" + assert ( + task.failed_data_rows[0]["message"] + == f"Duplicate global key: '{global_key_1}'" + ) # delete datarow client.get_data_row(new_data_row_id).delete() @@ -1013,46 +1082,49 @@ def test_data_row_delete_and_create_with_same_global_key( task.wait_till_done() assert task.status == "COMPLETE" - assert task.result[0]['global_key'] == global_key_1 + assert task.result[0]["global_key"] == global_key_1 def test_data_row_bulk_creation_sync_with_unique_global_keys( - dataset, sample_image): + dataset, sample_image +): global_key_1 = str(uuid.uuid4()) global_key_2 = str(uuid.uuid4()) global_key_3 = str(uuid.uuid4()) - dataset.create_data_rows_sync([ - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }, - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_2 - }, - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_3 - }, - ]) + dataset.create_data_rows_sync( + [ + {DataRow.row_data: sample_image, DataRow.global_key: global_key_1}, + {DataRow.row_data: sample_image, DataRow.global_key: global_key_2}, + {DataRow.row_data: sample_image, DataRow.global_key: global_key_3}, + ] + ) - assert {row.global_key for row in dataset.data_rows() - } == {global_key_1, global_key_2, global_key_3} + assert {row.global_key for row in dataset.data_rows()} == { + global_key_1, + global_key_2, + global_key_3, + } def test_data_row_bulk_creation_sync_with_same_global_keys( - dataset, sample_image): + dataset, sample_image +): global_key_1 = str(uuid.uuid4()) with pytest.raises(ResourceCreationError) as exc_info: - dataset.create_data_rows_sync([{ - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }, { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }]) + dataset.create_data_rows_sync( + [ + { + DataRow.row_data: sample_image, + DataRow.global_key: global_key_1, + }, + { + DataRow.row_data: sample_image, + DataRow.global_key: global_key_1, + }, + ] + ) assert len(list(dataset.data_rows())) == 1 assert list(dataset.data_rows())[0].global_key == global_key_1 @@ -1064,13 +1136,13 @@ def test_data_row_bulk_creation_sync_with_same_global_keys( def conversational_data_rows(dataset, conversational_content): examples = [ { - **conversational_content, 'media_type': - MediaType.Conversational.value + **conversational_content, + "media_type": MediaType.Conversational.value, }, conversational_content, { - "conversationalData": conversational_content['row_data']['messages'] - } # Old way to check for backwards compatibility + "conversationalData": conversational_content["row_data"]["messages"] + }, # Old way to check for backwards compatibility ] task = dataset.create_data_rows(examples) task.wait_till_done() @@ -1083,49 +1155,47 @@ def conversational_data_rows(dataset, 
conversational_content): dr.delete() -def test_create_conversational_text(conversational_data_rows, - conversational_content): +def test_create_conversational_text( + conversational_data_rows, conversational_content +): data_rows = conversational_data_rows for data_row in data_rows: - assert json.loads( - data_row.row_data) == conversational_content['row_data'] + assert ( + json.loads(data_row.row_data) == conversational_content["row_data"] + ) def test_invalid_media_type(dataset, conversational_content): - for _, __ in [["Found invalid contents for media type: 'IMAGE'", 'IMAGE'], - [ - "Found invalid media type: 'totallyinvalid'", - 'totallyinvalid' - ]]: + for _, __ in [ + ["Found invalid contents for media type: 'IMAGE'", "IMAGE"], + ["Found invalid media type: 'totallyinvalid'", "totallyinvalid"], + ]: # TODO: What error kind should this be? It looks like for global key we are # using malformed query. But for invalid contents in FileUploads we use InvalidQueryError with pytest.raises(ResourceCreationError): - dataset.create_data_rows_sync([{ - **conversational_content, 'media_type': 'IMAGE' - }]) + dataset.create_data_rows_sync( + [{**conversational_content, "media_type": "IMAGE"}] + ) def test_create_tiled_layer(dataset, tile_content): examples = [ - { - **tile_content, 'media_type': 'TMS_GEO' - }, + {**tile_content, "media_type": "TMS_GEO"}, tile_content, ] dataset.create_data_rows_sync(examples) data_rows = list(dataset.data_rows()) assert len(data_rows) == len(examples) for data_row in data_rows: - assert json.loads(data_row.row_data) == tile_content['row_data'] + assert json.loads(data_row.row_data) == tile_content["row_data"] def test_create_data_row_with_attachments(dataset): - attachment_value = 'attachment value' - dr = dataset.create_data_row(row_data="123", - attachments=[{ - 'type': 'RAW_TEXT', - 'value': attachment_value - }]) + attachment_value = "attachment value" + dr = dataset.create_data_row( + row_data="123", + attachments=[{"type": "RAW_TEXT", "value": attachment_value}], + ) attachments = list(dr.attachments()) assert len(attachments) == 1 @@ -1133,7 +1203,8 @@ def test_create_data_row_with_attachments(dataset): def test_create_data_row_with_media_type(dataset, image_url): with pytest.raises(ResourceCreationError) as exc: dr = dataset.create_data_row( - row_data={'invalid_object': 'invalid_value'}, media_type="IMAGE") + row_data={"invalid_object": "invalid_value"}, media_type="IMAGE" + ) assert "Expected type image/*, detected: application/json" in str(exc.value) diff --git a/libs/labelbox/tests/integration/test_data_rows_upsert.py b/libs/labelbox/tests/integration/test_data_rows_upsert.py index da99eecc6..2ba7a9df9 100644 --- a/libs/labelbox/tests/integration/test_data_rows_upsert.py +++ b/libs/labelbox/tests/integration/test_data_rows_upsert.py @@ -9,87 +9,70 @@ class TestDataRowUpsert: - @pytest.fixture def all_inclusive_data_row(self, dataset, image_url): dr = dataset.create_data_row( row_data=image_url, external_id="ex1", global_key=str(uuid.uuid4()), - metadata_fields=[{ - "name": "tag", - "value": "tag_string" - }, { - "name": "split", - "value": "train" - }], + metadata_fields=[ + {"name": "tag", "value": "tag_string"}, + {"name": "split", "value": "train"}, + ], attachments=[ + {"type": "RAW_TEXT", "name": "att1", "value": "test1"}, { - "type": "RAW_TEXT", - "name": "att1", - "value": "test1" - }, - { - "type": - "IMAGE", - "name": - "att2", - "value": - "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg" + "type": 
"IMAGE", + "name": "att2", + "value": "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg", }, { - "type": - "PDF_URL", - "name": - "att3", - "value": - "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" + "type": "PDF_URL", + "name": "att3", + "value": "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf", }, - ]) + ], + ) return dr @pytest.mark.order(1) def test_create_data_row_with_auto_key(self, dataset, image_url): - task = dataset.upsert_data_rows([{'row_data': image_url}]) + task = dataset.upsert_data_rows([{"row_data": image_url}]) task.wait_till_done() assert len(list(dataset.data_rows())) == 1 def test_create_data_row_with_upsert(self, client, dataset, image_url): gkey = str(uuid.uuid4()) - task = dataset.upsert_data_rows([{ - 'row_data': - image_url, - 'global_key': - gkey, - 'external_id': - "ex1", - 'attachments': [{ - 'type': AttachmentType.RAW_TEXT, - 'name': "att1", - 'value': "test1" - }, { - 'type': - AttachmentType.IMAGE, - 'name': - "att2", - 'value': - "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg" - }, { - 'type': - AttachmentType.PDF_URL, - 'name': - "att3", - 'value': - "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" - }], - 'metadata': [{ - 'name': "tag", - 'value': "updated tag" - }, { - 'name': "split", - 'value': "train" - }] - }]) + task = dataset.upsert_data_rows( + [ + { + "row_data": image_url, + "global_key": gkey, + "external_id": "ex1", + "attachments": [ + { + "type": AttachmentType.RAW_TEXT, + "name": "att1", + "value": "test1", + }, + { + "type": AttachmentType.IMAGE, + "name": "att2", + "value": "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg", + }, + { + "type": AttachmentType.PDF_URL, + "name": "att3", + "value": "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf", + }, + ], + "metadata": [ + {"name": "tag", "value": "updated tag"}, + {"name": "split", "value": "train"}, + ], + } + ] + ) task.wait_till_done() assert task.status == "COMPLETE" dr = client.get_data_row_by_global_key(gkey) @@ -107,31 +90,40 @@ def test_create_data_row_with_upsert(self, client, dataset, image_url): assert attachments[1].attachment_name == "att2" assert attachments[1].attachment_type == AttachmentType.IMAGE - assert attachments[ - 1].attachment_value == "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg" + assert ( + attachments[1].attachment_value + == "https://storage.googleapis.com/labelbox-sample-datasets/Docs/disease_attachment.jpeg" + ) assert attachments[2].attachment_name == "att3" assert attachments[2].attachment_type == AttachmentType.PDF_URL - assert attachments[ - 2].attachment_value == "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" + assert ( + attachments[2].attachment_value + == "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" + ) assert len(dr.metadata_fields) == 2 - assert dr.metadata_fields[0]['name'] == "tag" - assert dr.metadata_fields[0]['value'] == "updated tag" - assert dr.metadata_fields[1]['name'] == "split" - assert dr.metadata_fields[1]['value'] == "train" + assert dr.metadata_fields[0]["name"] == "tag" + assert dr.metadata_fields[0]["value"] == "updated tag" + assert dr.metadata_fields[1]["name"] == "split" + 
assert dr.metadata_fields[1]["value"] == "train" - def test_update_data_row_fields_with_upsert(self, client, dataset, - image_url): + def test_update_data_row_fields_with_upsert( + self, client, dataset, image_url + ): gkey = str(uuid.uuid4()) - dr = dataset.create_data_row(row_data=image_url, - external_id="ex1", - global_key=gkey) - task = dataset.upsert_data_rows([{ - 'key': UniqueId(dr.uid), - 'external_id': "ex1_updated", - 'global_key': f"{gkey}_updated" - }]) + dr = dataset.create_data_row( + row_data=image_url, external_id="ex1", global_key=gkey + ) + task = dataset.upsert_data_rows( + [ + { + "key": UniqueId(dr.uid), + "external_id": "ex1_updated", + "global_key": f"{gkey}_updated", + } + ] + ) task.wait_till_done() assert task.status == "COMPLETE" dr = client.get_data_row(dr.uid) @@ -140,16 +132,21 @@ def test_update_data_row_fields_with_upsert(self, client, dataset, assert dr.global_key == f"{gkey}_updated" def test_update_data_row_fields_with_upsert_by_global_key( - self, client, dataset, image_url): + self, client, dataset, image_url + ): gkey = str(uuid.uuid4()) - dr = dataset.create_data_row(row_data=image_url, - external_id="ex1", - global_key=gkey) - task = dataset.upsert_data_rows([{ - 'key': GlobalKey(dr.global_key), - 'external_id': "ex1_updated", - 'global_key': f"{gkey}_updated" - }]) + dr = dataset.create_data_row( + row_data=image_url, external_id="ex1", global_key=gkey + ) + task = dataset.upsert_data_rows( + [ + { + "key": GlobalKey(dr.global_key), + "external_id": "ex1_updated", + "global_key": f"{gkey}_updated", + } + ] + ) task.wait_till_done() assert task.status == "COMPLETE" dr = client.get_data_row(dr.uid) @@ -157,20 +154,25 @@ def test_update_data_row_fields_with_upsert_by_global_key( assert dr.external_id == "ex1_updated" assert dr.global_key == f"{gkey}_updated" - def test_update_attachments_with_upsert(self, client, - all_inclusive_data_row, dataset): + def test_update_attachments_with_upsert( + self, client, all_inclusive_data_row, dataset + ): dr = all_inclusive_data_row - task = dataset.upsert_data_rows([{ - 'key': - UniqueId(dr.uid), - 'row_data': - dr.row_data, - 'attachments': [{ - 'type': AttachmentType.RAW_TEXT, - 'name': "att1", - 'value': "test" - }] - }]) + task = dataset.upsert_data_rows( + [ + { + "key": UniqueId(dr.uid), + "row_data": dr.row_data, + "attachments": [ + { + "type": AttachmentType.RAW_TEXT, + "name": "att1", + "value": "test", + } + ], + } + ] + ) task.wait_till_done() assert task.status == "COMPLETE" dr = client.get_data_row(dr.uid) @@ -179,44 +181,49 @@ def test_update_attachments_with_upsert(self, client, assert len(attachments) == 1 assert attachments[0].attachment_name == "att1" - def test_update_metadata_with_upsert(self, client, all_inclusive_data_row, - dataset): + def test_update_metadata_with_upsert( + self, client, all_inclusive_data_row, dataset + ): dr = all_inclusive_data_row - task = dataset.upsert_data_rows([{ - 'key': - GlobalKey(dr.global_key), - 'row_data': - dr.row_data, - 'metadata': [{ - 'name': "tag", - 'value': "updated tag" - }, { - 'name': "split", - 'value': "train" - }] - }]) + task = dataset.upsert_data_rows( + [ + { + "key": GlobalKey(dr.global_key), + "row_data": dr.row_data, + "metadata": [ + {"name": "tag", "value": "updated tag"}, + {"name": "split", "value": "train"}, + ], + } + ] + ) task.wait_till_done() assert task.status == "COMPLETE" dr = client.get_data_row(dr.uid) assert dr is not None assert len(dr.metadata_fields) == 2 - assert dr.metadata_fields[0]['name'] == "tag" - assert 
dr.metadata_fields[0]['value'] == "updated tag" - assert dr.metadata_fields[1]['name'] == "split" - assert dr.metadata_fields[1]['value'] == "train" + assert dr.metadata_fields[0]["name"] == "tag" + assert dr.metadata_fields[0]["value"] == "updated tag" + assert dr.metadata_fields[1]["name"] == "split" + assert dr.metadata_fields[1]["value"] == "train" def test_multiple_chunks(self, client, dataset, image_url): mocked_chunk_size = 300 - with patch('labelbox.client.Client.upload_data', - wraps=client.upload_data) as spy_some_function: - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', - new=mocked_chunk_size): - task = dataset.upsert_data_rows([{ - 'row_data': image_url - } for i in range(10)]) + with patch( + "labelbox.client.Client.upload_data", wraps=client.upload_data + ) as spy_some_function: + with patch( + "labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES", + new=mocked_chunk_size, + ): + task = dataset.upsert_data_rows( + [{"row_data": image_url} for i in range(10)] + ) task.wait_till_done() assert len(list(dataset.data_rows())) == 10 - assert spy_some_function.call_count == 11 # one per each data row + manifest + assert ( + spy_some_function.call_count == 11 + ) # one per each data row + manifest first_call_args, _ = spy_some_function.call_args_list[0] first_chunk_content = first_call_args[0] @@ -228,23 +235,25 @@ def test_multiple_chunks(self, client, dataset, image_url): assert len(data) in {1, 3} last_call_args, _ = spy_some_function.call_args_list[-1] - manifest_content = last_call_args[0].decode('utf-8') + manifest_content = last_call_args[0].decode("utf-8") data = json.loads(manifest_content) - assert data['source'] == "SDK" - assert data['item_count'] == 10 - assert len(data['chunk_uris']) == 10 + assert data["source"] == "SDK" + assert data["item_count"] == 10 + assert len(data["chunk_uris"]) == 10 def test_upsert_embedded_row_data(self, dataset): pdf_url = "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/0801.3483.pdf" - task = dataset.upsert_data_rows([{ - 'row_data': { - "pdf_url": - pdf_url, - "text_layer_url": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/0801.3483-lb-textlayer.json" - }, - 'media_type': "PDF" - }]) + task = dataset.upsert_data_rows( + [ + { + "row_data": { + "pdf_url": pdf_url, + "text_layer_url": "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/0801.3483-lb-textlayer.json", + }, + "media_type": "PDF", + } + ] + ) task.wait_till_done() data_rows = list(dataset.data_rows()) assert len(data_rows) == 1 @@ -252,21 +261,17 @@ def test_upsert_embedded_row_data(self, dataset): def test_upsert_duplicate_global_key_error(self, dataset, image_url): gkey = str(uuid.uuid4()) - task = dataset.upsert_data_rows([ - { - 'row_data': image_url, - 'global_key': gkey - }, - { - 'row_data': image_url, - 'global_key': gkey - }, - ]) + task = dataset.upsert_data_rows( + [ + {"row_data": image_url, "global_key": gkey}, + {"row_data": image_url, "global_key": gkey}, + ] + ) task.wait_till_done() assert task.status == "COMPLETE" assert task.errors is not None assert len(task.errors) == 1 # one data row was created, one failed - assert f"Duplicate global key: '{gkey}'" in task.errors[0]['message'] + assert f"Duplicate global key: '{gkey}'" in task.errors[0]["message"] def test_upsert_empty_items(self, dataset): items = [{"key": GlobalKey("foo")}] diff --git a/libs/labelbox/tests/integration/test_dataset.py b/libs/labelbox/tests/integration/test_dataset.py index 51a43a09c..89210d6c9 100644 --- 
a/libs/labelbox/tests/integration/test_dataset.py +++ b/libs/labelbox/tests/integration/test_dataset.py @@ -4,11 +4,12 @@ from labelbox import Dataset from labelbox.exceptions import ResourceNotFoundError, ResourceCreationError -from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator +from labelbox.schema.internal.descriptor_file_creator import ( + DescriptorFileCreator, +) def test_dataset(client, rand_gen): - # confirm dataset can be created name = rand_gen(str) dataset = client.create_dataset(name=name) @@ -76,8 +77,9 @@ def test_get_data_row_for_external_id(dataset, rand_gen, image_url): with pytest.raises(ResourceNotFoundError): data_row = dataset.data_row_for_external_id(external_id) - data_row = dataset.create_data_row(row_data=image_url, - external_id=external_id) + data_row = dataset.create_data_row( + row_data=image_url, external_id=external_id + ) found = dataset.data_row_for_external_id(external_id) assert found.uid == data_row.uid @@ -87,7 +89,8 @@ def test_get_data_row_for_external_id(dataset, rand_gen, image_url): assert len(dataset.data_rows_for_external_id(external_id)) == 2 task = dataset.create_data_rows( - [dict(row_data=image_url, external_id=external_id)]) + [dict(row_data=image_url, external_id=external_id)] + ) task.wait_till_done() assert len(dataset.data_rows_for_external_id(external_id)) == 3 @@ -102,41 +105,40 @@ def test_upload_video_file(dataset, sample_video: str) -> None: task = dataset.create_data_rows([sample_video, sample_video]) task.wait_till_done() - with open(sample_video, 'rb') as video_f: + with open(sample_video, "rb") as video_f: content_length = len(video_f.read()) for data_row in dataset.data_rows(): url = data_row.row_data response = requests.head(url, allow_redirects=True) - assert int(response.headers['Content-Length']) == content_length - assert response.headers['Content-Type'] == 'video/mp4' + assert int(response.headers["Content-Length"]) == content_length + assert response.headers["Content-Type"] == "video/mp4" def test_create_pdf(dataset): dataset.create_data_row( row_data={ - "pdfUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", - "textLayerUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json" - }) - dataset.create_data_row(row_data={ - "pdfUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", - "textLayerUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json" - }, - media_type="PDF") + "pdfUrl": "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", + "textLayerUrl": "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json", + } + ) + dataset.create_data_row( + row_data={ + "pdfUrl": "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", + "textLayerUrl": "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json", + }, + media_type="PDF", + ) with pytest.raises(ResourceCreationError): # Wrong media type - dataset.create_data_row(row_data={ - "pdfUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", - "textLayerUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json" - }, - media_type="TEXT") + dataset.create_data_row( + row_data={ + "pdfUrl": 
"https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", + "textLayerUrl": "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json", + }, + media_type="TEXT", + ) def test_bulk_conversation(dataset, sample_bulk_conversation: list) -> None: @@ -152,17 +154,21 @@ def test_bulk_conversation(dataset, sample_bulk_conversation: list) -> None: def test_create_descriptor_file(dataset): import unittest.mock as mock + client = MagicMock() - with mock.patch.object(client, 'upload_data', - wraps=client.upload_data) as upload_data_spy: - DescriptorFileCreator(client).create_one(items=[{ - 'row_data': 'some text...' - }]) + with mock.patch.object( + client, "upload_data", wraps=client.upload_data + ) as upload_data_spy: + DescriptorFileCreator(client).create_one( + items=[{"row_data": "some text..."}] + ) upload_data_spy.assert_called() - call_args, call_kwargs = upload_data_spy.call_args_list[0][ - 0], upload_data_spy.call_args_list[0][1] + call_args, call_kwargs = ( + upload_data_spy.call_args_list[0][0], + upload_data_spy.call_args_list[0][1], + ) assert call_args == ('[{"row_data": "some text..."}]',) assert call_kwargs == { - 'content_type': 'application/json', - 'filename': 'json_import.json' + "content_type": "application/json", + "filename": "json_import.json", } diff --git a/libs/labelbox/tests/integration/test_delegated_access.py b/libs/labelbox/tests/integration/test_delegated_access.py index 1592319d2..0e6422b08 100644 --- a/libs/labelbox/tests/integration/test_delegated_access.py +++ b/libs/labelbox/tests/integration/test_delegated_access.py @@ -8,37 +8,39 @@ @pytest.mark.skip( - reason= - "Google credentials are being updated for this test, disabling till it's all sorted out" + reason="Google credentials are being updated for this test, disabling till it's all sorted out" +) +@pytest.mark.skipif( + not os.environ.get("DA_GCP_LABELBOX_API_KEY"), + reason="DA_GCP_LABELBOX_API_KEY not found", ) -@pytest.mark.skipif(not os.environ.get('DA_GCP_LABELBOX_API_KEY'), - reason="DA_GCP_LABELBOX_API_KEY not found") def test_default_integration(): """ This tests assumes the following: 1. gcp delegated access is configured to work with jtso-gcs-sdk-da-tests 2. the integration name is gcs sdk test bucket 3. 
This integration is the default - + Currently tests against: Org ID: cl269lvvj78b50zau34s4550z Email: jtso+gcp_sdk_tests@labelbox.com""" client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY")) ds = client.create_dataset(name="new_ds") dr = ds.create_data_row( - row_data= - "gs://jtso-gcs-sdk-da-tests/nikita-samokhin-D6QS6iv_CTY-unsplash.jpg") + row_data="gs://jtso-gcs-sdk-da-tests/nikita-samokhin-D6QS6iv_CTY-unsplash.jpg" + ) assert requests.get(dr.row_data).status_code == 200 assert ds.iam_integration().name == "gcs sdk test bucket" ds.delete() @pytest.mark.skip( - reason= - "Google credentials are being updated for this test, disabling till it's all sorted out" + reason="Google credentials are being updated for this test, disabling till it's all sorted out" +) +@pytest.mark.skipif( + not os.environ.get("DA_GCP_LABELBOX_API_KEY"), + reason="DA_GCP_LABELBOX_API_KEY not found", ) -@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"), - reason="DA_GCP_LABELBOX_API_KEY not found") def test_non_default_integration(): """ This tests assumes the following: @@ -52,14 +54,13 @@ def test_non_default_integration(): client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY")) integrations = client.get_organization().get_iam_integrations() integration = [ - inte for inte in integrations if 'aws-da-test-bucket' in inte.name + inte for inte in integrations if "aws-da-test-bucket" in inte.name ][0] assert integration.valid ds = client.create_dataset(iam_integration=integration, name="new_ds") assert ds.iam_integration().name == "aws-da-test-bucket" dr = ds.create_data_row( - row_data= - "https://jtso-aws-da-sdk-tests.s3.us-east-2.amazonaws.com/adrian-yu-qkN4D3Rf1gw-unsplash.jpg" + row_data="https://jtso-aws-da-sdk-tests.s3.us-east-2.amazonaws.com/adrian-yu-qkN4D3Rf1gw-unsplash.jpg" ) assert requests.get(dr.row_data).status_code == 200 ds.delete() @@ -81,15 +82,16 @@ def test_no_default_integration(client): @pytest.mark.skip( - reason= - "Google credentials are being updated for this test, disabling till it's all sorted out" + reason="Google credentials are being updated for this test, disabling till it's all sorted out" +) +@pytest.mark.skipif( + not os.environ.get("DA_GCP_LABELBOX_API_KEY"), + reason="DA_GCP_LABELBOX_API_KEY not found", ) -@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"), - reason="DA_GCP_LABELBOX_API_KEY not found") def test_add_integration_from_object(): """ This test is based on test_non_default_integration() and assumes the following: - + 1. aws delegated access is configured to work with lbox-test-bucket 2. 
an integration called aws is available to the org @@ -102,11 +104,14 @@ def test_add_integration_from_object(): # Prepare dataset with no integration integration = [ - integration for integration - in integrations - if 'aws-da-test-bucket' in integration.name][0] + integration + for integration in integrations + if "aws-da-test-bucket" in integration.name + ][0] - ds = client.create_dataset(iam_integration=None, name=f"integration_add_obj-{uuid.uuid4()}") + ds = client.create_dataset( + iam_integration=None, name=f"integration_add_obj-{uuid.uuid4()}" + ) # Test set integration with object new_integration = ds.add_iam_integration(integration) @@ -115,16 +120,18 @@ def test_add_integration_from_object(): # Cleaning ds.delete() + @pytest.mark.skip( - reason= - "Google credentials are being updated for this test, disabling till it's all sorted out" + reason="Google credentials are being updated for this test, disabling till it's all sorted out" +) +@pytest.mark.skipif( + not os.environ.get("DA_GCP_LABELBOX_API_KEY"), + reason="DA_GCP_LABELBOX_API_KEY not found", ) -@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"), - reason="DA_GCP_LABELBOX_API_KEY not found") def test_add_integration_from_uid(): """ This test is based on test_non_default_integration() and assumes the following: - + 1. aws delegated access is configured to work with lbox-test-bucket 2. an integration called aws is available to the org @@ -137,34 +144,40 @@ def test_add_integration_from_uid(): # Prepare dataset with no integration integration = [ - integration for integration - in integrations - if 'aws-da-test-bucket' in integration.name][0] + integration + for integration in integrations + if "aws-da-test-bucket" in integration.name + ][0] - ds = client.create_dataset(iam_integration=None, name=f"integration_add_id-{uuid.uuid4()}") + ds = client.create_dataset( + iam_integration=None, name=f"integration_add_id-{uuid.uuid4()}" + ) # Test set integration with integration id integration_id = [ - integration.uid for integration - in integrations - if 'aws-da-test-bucket' in integration.name][0] - + integration.uid + for integration in integrations + if "aws-da-test-bucket" in integration.name + ][0] + new_integration = ds.add_iam_integration(integration_id) assert new_integration == integration # Cleaning ds.delete() + @pytest.mark.skip( - reason= - "Google credentials are being updated for this test, disabling till it's all sorted out" + reason="Google credentials are being updated for this test, disabling till it's all sorted out" +) +@pytest.mark.skipif( + not os.environ.get("DA_GCP_LABELBOX_API_KEY"), + reason="DA_GCP_LABELBOX_API_KEY not found", ) -@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"), - reason="DA_GCP_LABELBOX_API_KEY not found") def test_integration_remove(): """ This test is based on test_non_default_integration() and assumes the following: - + 1. aws delegated access is configured to work with lbox-test-bucket 2. 
an integration called aws is available to the org @@ -177,15 +190,18 @@ def test_integration_remove(): # Prepare dataset with an existing integration integration = [ - integration for integration - in integrations - if 'aws-da-test-bucket' in integration.name][0] + integration + for integration in integrations + if "aws-da-test-bucket" in integration.name + ][0] - ds = client.create_dataset(iam_integration=integration, name=f"integration_remove-{uuid.uuid4()}") + ds = client.create_dataset( + iam_integration=integration, name=f"integration_remove-{uuid.uuid4()}" + ) # Test unset integration ds.remove_iam_integration() assert ds.iam_integration() is None # Cleaning - ds.delete() \ No newline at end of file + ds.delete() diff --git a/libs/labelbox/tests/integration/test_embedding.py b/libs/labelbox/tests/integration/test_embedding.py index 541b6d980..1b54ab81c 100644 --- a/libs/labelbox/tests/integration/test_embedding.py +++ b/libs/labelbox/tests/integration/test_embedding.py @@ -27,9 +27,10 @@ def test_get_embedding_by_name_not_found(client: Client): client.get_embedding_by_name("does-not-exist") -@pytest.mark.parametrize('data_rows', [10], indirect=True) -def test_import_vectors_from_file(data_rows: List[DataRow], - embedding: Embedding): +@pytest.mark.parametrize("data_rows", [10], indirect=True) +def test_import_vectors_from_file( + data_rows: List[DataRow], embedding: Embedding +): vector = [random.uniform(1.0, 2.0) for _ in range(embedding.dims)] event = threading.Event() @@ -38,10 +39,7 @@ def callback(_: Dict[str, Any]): with NamedTemporaryFile(mode="w+") as fp: lines = [ - json.dumps({ - "id": dr.uid, - "vector": vector - }) for dr in data_rows + json.dumps({"id": dr.uid, "vector": vector}) for dr in data_rows ] fp.writelines(lines) fp.flush() @@ -54,10 +52,9 @@ def test_get_imported_vector_count(dataset: Dataset, embedding: Embedding): assert embedding.get_imported_vector_count() == 0 vector = [random.uniform(1.0, 2.0) for _ in range(embedding.dims)] - dataset.create_data_row(row_data="foo", - embeddings=[{ - "embedding_id": embedding.id, - "vector": vector - }]) + dataset.create_data_row( + row_data="foo", + embeddings=[{"embedding_id": embedding.id, "vector": vector}], + ) assert embedding.get_imported_vector_count() == 1 diff --git a/libs/labelbox/tests/integration/test_ephemeral.py b/libs/labelbox/tests/integration/test_ephemeral.py index 6ebcf61c6..a23572fdf 100644 --- a/libs/labelbox/tests/integration/test_ephemeral.py +++ b/libs/labelbox/tests/integration/test_ephemeral.py @@ -2,8 +2,10 @@ import pytest -@pytest.mark.skipif(not os.environ.get('LABELBOX_TEST_ENVIRON') == 'ephemeral', - reason='This test only runs in EPHEMERAL environment') +@pytest.mark.skipif( + not os.environ.get("LABELBOX_TEST_ENVIRON") == "ephemeral", + reason="This test only runs in EPHEMERAL environment", +) def test_org_and_user_setup(client, ephmeral_client): assert type(client) == ephmeral_client assert client.admin_client @@ -15,7 +17,9 @@ def test_org_and_user_setup(client, ephmeral_client): assert user -@pytest.mark.skipif(os.environ.get('LABELBOX_TEST_ENVIRON') == 'ephemeral', - reason='This test does not run in EPHEMERAL environment') +@pytest.mark.skipif( + os.environ.get("LABELBOX_TEST_ENVIRON") == "ephemeral", + reason="This test does not run in EPHEMERAL environment", +) def test_integration_client(client, integration_client): assert type(client) == integration_client diff --git a/libs/labelbox/tests/integration/test_feature_schema.py b/libs/labelbox/tests/integration/test_feature_schema.py 
index 1dc25efc1..1dc940f08 100644 --- a/libs/labelbox/tests/integration/test_feature_schema.py +++ b/libs/labelbox/tests/integration/test_feature_schema.py @@ -12,36 +12,37 @@ def test_deletes_a_feature_schema(client): tool = client.upsert_feature_schema(point.asdict()) - assert client.delete_unused_feature_schema( - tool.normalized['featureSchemaId']) is None + assert ( + client.delete_unused_feature_schema(tool.normalized["featureSchemaId"]) + is None + ) def test_cant_delete_already_deleted_feature_schema(client): tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] client.delete_unused_feature_schema(feature_schema_id) is None with pytest.raises( - Exception, - match= - "Failed to delete the feature schema, message: Feature schema is already deleted" + Exception, + match="Failed to delete the feature schema, message: Feature schema is already deleted", ): client.delete_unused_feature_schema(feature_schema_id) def test_cant_delete_feature_schema_with_ontology(client): tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] ontology = client.create_ontology_from_feature_schemas( - name='ontology name', + name="ontology name", feature_schema_ids=[feature_schema_id], - media_type=MediaType.Image) + media_type=MediaType.Image, + ) with pytest.raises( - Exception, - match= - "Failed to delete the feature schema, message: Feature schema cannot be deleted because it is used in ontologies" + Exception, + match="Failed to delete the feature schema, message: Feature schema cannot be deleted because it is used in ontologies", ): client.delete_unused_feature_schema(feature_schema_id) @@ -51,29 +52,30 @@ def test_cant_delete_feature_schema_with_ontology(client): def test_throws_an_error_if_feature_schema_to_delete_doesnt_exist(client): with pytest.raises( - Exception, - match= - "Failed to delete the feature schema, message: Cannot find root schema node with feature schema id doesntexist" + Exception, + match="Failed to delete the feature schema, message: Cannot find root schema node with feature schema id doesntexist", ): client.delete_unused_feature_schema("doesntexist") def test_updates_a_feature_schema_title(client): tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] new_title = "new title" updated_feature_schema = client.update_feature_schema_title( - feature_schema_id, new_title) + feature_schema_id, new_title + ) - assert updated_feature_schema.normalized['name'] == new_title + assert updated_feature_schema.normalized["name"] == new_title client.delete_unused_feature_schema(feature_schema_id) def test_throws_an_error_when_updating_a_feature_schema_with_empty_title( - client): + client, +): tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] with pytest.raises(Exception): client.update_feature_schema_title(feature_schema_id, "") @@ -96,21 +98,23 @@ def test_updates_a_feature_schema(client, feature_schema): tool=Tool.Type.POINT, name="new name", color="#ff0000", - feature_schema_id=created_feature_schema.normalized['featureSchemaId'], + feature_schema_id=created_feature_schema.normalized["featureSchemaId"], ) updated_feature_schema = 
client.upsert_feature_schema( - tool_to_update.asdict()) + tool_to_update.asdict() + ) - assert updated_feature_schema.normalized['name'] == "new name" + assert updated_feature_schema.normalized["name"] == "new name" def test_does_not_include_used_feature_schema(client): tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] ontology = client.create_ontology_from_feature_schemas( - name='ontology name', + name="ontology name", feature_schema_ids=[feature_schema_id], - media_type=MediaType.Image) + media_type=MediaType.Image, + ) unused_feature_schemas = client.get_unused_feature_schemas() assert feature_schema_id not in unused_feature_schemas diff --git a/libs/labelbox/tests/integration/test_filtering.py b/libs/labelbox/tests/integration/test_filtering.py index e751213cc..2e09ba573 100644 --- a/libs/labelbox/tests/integration/test_filtering.py +++ b/libs/labelbox/tests/integration/test_filtering.py @@ -30,8 +30,9 @@ def test_where(client, project_to_test_where): p_b_name = p_b.name def get(where=None): - date_where = Project.created_at >= min(p_a.created_at, p_b.created_at, - p_c.created_at) + date_where = Project.created_at >= min( + p_a.created_at, p_b.created_at, p_c.created_at + ) where = date_where if where is None else where & date_where return {p.uid for p in client.get_projects(where)} @@ -47,14 +48,16 @@ def get(where=None): ge_b = get(Project.name >= p_b_name) assert {p_b.uid, p_c.uid}.issubset(ge_b) and p_a.uid not in ge_b + def test_unsupported_where(client): with pytest.raises(InvalidQueryError): client.get_projects(where=(Project.name == "a") & (Project.name == "b")) # TODO support logical OR and NOT in where with pytest.raises(InvalidQueryError): - client.get_projects(where=(Project.name == "a") | - (Project.description == "b")) + client.get_projects( + where=(Project.name == "a") | (Project.description == "b") + ) with pytest.raises(InvalidQueryError): client.get_projects(where=~(Project.name == "a")) diff --git a/libs/labelbox/tests/integration/test_foundry.py b/libs/labelbox/tests/integration/test_foundry.py index 10d6be85b..83c4effc5 100644 --- a/libs/labelbox/tests/integration/test_foundry.py +++ b/libs/labelbox/tests/integration/test_foundry.py @@ -21,14 +21,15 @@ def foundry_client(client): @pytest.fixture() def text_data_row(dataset, random_str): global_key = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-1.txt-{random_str}" - task = dataset.create_data_rows([{ - "row_data": - "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-1.txt", - "media_type": - "TEXT", - "global_key": - global_key - }]) + task = dataset.create_data_rows( + [ + { + "row_data": "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-1.txt", + "media_type": "TEXT", + "global_key": global_key, + } + ] + ) task.wait_till_done() dr = dataset.data_rows().get_one() yield dr @@ -38,32 +39,40 @@ def text_data_row(dataset, random_str): @pytest.fixture() def ontology(client, random_str): object_features = [ - lb.Tool(tool=lb.Tool.Type.BBOX, - name="text", - color="#ff0000", - classifications=[ - lb.Classification(class_type=lb.Classification.Type.TEXT, - name="value") - ]) + lb.Tool( + tool=lb.Tool.Type.BBOX, + name="text", + color="#ff0000", + classifications=[ + lb.Classification( + class_type=lb.Classification.Type.TEXT, name="value" + ) + ], + ) ] - ontology_builder = 
lb.OntologyBuilder(tools=object_features,) + ontology_builder = lb.OntologyBuilder( + tools=object_features, + ) ontology = client.create_ontology( f"Test ontology for tesseract model {random_str}", ontology_builder.asdict(), - media_type=lb.MediaType.Image) + media_type=lb.MediaType.Image, + ) return ontology @pytest.fixture() def unsaved_app(random_str, ontology): - return App(model_id=TEST_MODEL_ID, - name=f"Test App {random_str}", - description="Test App Description", - inference_params={"confidence": 0.2}, - class_to_schema_id={}, - ontology_id=ontology.uid) + return App( + model_id=TEST_MODEL_ID, + name=f"Test App {random_str}", + description="Test App Description", + inference_params={"confidence": 0.2}, + class_to_schema_id={}, + ontology_id=ontology.uid, + ) @pytest.fixture() @@ -75,15 +84,15 @@ def app(foundry_client, unsaved_app): def test_create_app(foundry_client, unsaved_app): app = foundry_client._create_app(unsaved_app) - retrieved_dict = app.model_dump(exclude={'id', 'created_by'}) - expected_dict = app.model_dump(exclude={'id', 'created_by'}) + retrieved_dict = app.model_dump(exclude={"id", "created_by"}) + expected_dict = app.model_dump(exclude={"id", "created_by"}) assert retrieved_dict == expected_dict def test_get_app(foundry_client, app): retrieved_app = foundry_client._get_app(app.id) - retrieved_dict = retrieved_app.model_dump(exclude={'created_by'}) - expected_dict = app.model_dump(exclude={'created_by'}) + retrieved_dict = retrieved_app.model_dump(exclude={"created_by"}) + expected_dict = app.model_dump(exclude={"created_by"}) assert retrieved_dict == expected_dict @@ -92,57 +101,65 @@ def test_get_app_with_invalid_id(foundry_client): foundry_client._get_app("invalid-id") -def test_run_foundry_app_with_data_row_id(foundry_client, data_row, app, - random_str): +def test_run_foundry_app_with_data_row_id( + foundry_client, data_row, app, random_str +): data_rows = lb.DataRowIds([data_row.uid]) task = foundry_client.run_app( model_run_name=f"test-app-with-datarow-id-{random_str}", data_rows=data_rows, - app_id=app.id) + app_id=app.id, + ) task.wait_till_done() - assert task.status == 'COMPLETE' + assert task.status == "COMPLETE" -def test_run_foundry_app_with_global_key(foundry_client, data_row, app, - random_str): +def test_run_foundry_app_with_global_key( + foundry_client, data_row, app, random_str +): data_rows = lb.GlobalKeys([data_row.global_key]) task = foundry_client.run_app( model_run_name=f"test-app-with-global-key-{random_str}", data_rows=data_rows, - app_id=app.id) + app_id=app.id, + ) task.wait_till_done() - assert task.status == 'COMPLETE' + assert task.status == "COMPLETE" -def test_run_foundry_app_returns_model_run_id(foundry_client, data_row, app, - random_str): +def test_run_foundry_app_returns_model_run_id( + foundry_client, data_row, app, random_str +): data_rows = lb.GlobalKeys([data_row.global_key]) task = foundry_client.run_app( model_run_name=f"test-app-with-global-key-{random_str}", data_rows=data_rows, - app_id=app.id) - model_run_id = task.metadata['modelRunId'] + app_id=app.id, + ) + model_run_id = task.metadata["modelRunId"] model_run = foundry_client.client.get_model_run(model_run_id) assert model_run.uid == model_run_id def test_run_foundry_with_invalid_data_row_id(foundry_client, app, random_str): - invalid_datarow_id = 'invalid-global-key' + invalid_datarow_id = "invalid-global-key" data_rows = lb.GlobalKeys([invalid_datarow_id]) with pytest.raises(lb.exceptions.LabelboxError) as exception: foundry_client.run_app( 
model_run_name=f"test-app-with-invalid-datarow-id-{random_str}", data_rows=data_rows, - app_id=app.id) + app_id=app.id, + ) assert invalid_datarow_id in exception.value def test_run_foundry_with_invalid_global_key(foundry_client, app, random_str): - invalid_global_key = 'invalid-global-key' + invalid_global_key = "invalid-global-key" data_rows = lb.GlobalKeys([invalid_global_key]) with pytest.raises(lb.exceptions.LabelboxError) as exception: foundry_client.run_app( model_run_name=f"test-app-with-invalid-global-key-{random_str}", data_rows=data_rows, - app_id=app.id) + app_id=app.id, + ) assert invalid_global_key in exception.value diff --git a/libs/labelbox/tests/integration/test_global_keys.py b/libs/labelbox/tests/integration/test_global_keys.py index 3fd3d84d9..9dc357812 100644 --- a/libs/labelbox/tests/integration/test_global_keys.py +++ b/libs/labelbox/tests/integration/test_global_keys.py @@ -14,38 +14,29 @@ def test_assign_global_keys_to_data_rows(client, dataset, image_url): gk_1 = str(uuid.uuid4()) gk_2 = str(uuid.uuid4()) - assignment_inputs = [{ - "data_row_id": dr_1.uid, - "global_key": gk_1 - }, { - "data_row_id": dr_2.uid, - "global_key": gk_2 - }] + assignment_inputs = [ + {"data_row_id": dr_1.uid, "global_key": gk_1}, + {"data_row_id": dr_2.uid, "global_key": gk_2}, + ] res = client.assign_global_keys_to_data_rows(assignment_inputs) - assert res['status'] == "SUCCESS" - assert res['errors'] == [] + assert res["status"] == "SUCCESS" + assert res["errors"] == [] - assert len(res['results']) == 2 - for r in res['results']: - del r['sanitized'] - assert res['results'] == assignment_inputs + assert len(res["results"]) == 2 + for r in res["results"]: + del r["sanitized"] + assert res["results"] == assignment_inputs def test_assign_global_keys_to_data_rows_validation_error(client): - assignment_inputs = [{ - "data_row_id": "test uid", - "wrong_key": "gk 1" - }, { - "data_row_id": "test uid 2", - "global_key": "gk 2" - }, { - "wrong_key": "test uid 3", - "global_key": "gk 3" - }, { - "data_row_id": "test uid 4" - }, { - "global_key": "gk 5" - }, {}] + assignment_inputs = [ + {"data_row_id": "test uid", "wrong_key": "gk 1"}, + {"data_row_id": "test uid 2", "global_key": "gk 2"}, + {"wrong_key": "test uid 3", "global_key": "gk 3"}, + {"data_row_id": "test uid 4"}, + {"global_key": "gk 5"}, + {}, + ] with pytest.raises(ValueError) as excinfo: client.assign_global_keys_to_data_rows(assignment_inputs) e = """[{'data_row_id': 'test uid', 'wrong_key': 'gk 1'}, {'wrong_key': 'test uid 3', 'global_key': 'gk 3'}, {'data_row_id': 'test uid 4'}, {'global_key': 'gk 5'}, {}]""" @@ -58,124 +49,123 @@ def test_assign_same_global_keys_to_data_rows(client, dataset, image_url): gk_1 = str(uuid.uuid4()) - assignment_inputs = [{ - "data_row_id": dr_1.uid, - "global_key": gk_1 - }, { - "data_row_id": dr_2.uid, - "global_key": gk_1 - }] + assignment_inputs = [ + {"data_row_id": dr_1.uid, "global_key": gk_1}, + {"data_row_id": dr_2.uid, "global_key": gk_1}, + ] res = client.assign_global_keys_to_data_rows(assignment_inputs) - assert res['status'] == "PARTIAL SUCCESS" - assert len(res['results']) == 1 - assert res['results'][0]['data_row_id'] == dr_1.uid - assert res['results'][0]['global_key'] == gk_1 + assert res["status"] == "PARTIAL SUCCESS" + assert len(res["results"]) == 1 + assert res["results"][0]["data_row_id"] == dr_1.uid + assert res["results"][0]["global_key"] == gk_1 - assert len(res['errors']) == 1 - assert res['errors'][0]['data_row_id'] == dr_2.uid - assert res['errors'][0]['global_key'] == 
gk_1 - assert res['errors'][0][ - 'error'] == "Invalid assignment. Either DataRow does not exist, or globalKey is invalid" + assert len(res["errors"]) == 1 + assert res["errors"][0]["data_row_id"] == dr_2.uid + assert res["errors"][0]["global_key"] == gk_1 + assert ( + res["errors"][0]["error"] + == "Invalid assignment. Either DataRow does not exist, or globalKey is invalid" + ) def test_long_global_key_validation(client, dataset, image_url): - long_global_key = 'x' * 201 + long_global_key = "x" * 201 dr_1 = dataset.create_data_row(row_data=image_url) dr_2 = dataset.create_data_row(row_data=image_url) gk_1 = str(uuid.uuid4()) gk_2 = long_global_key - assignment_inputs = [{ - "data_row_id": dr_1.uid, - "global_key": gk_1 - }, { - "data_row_id": dr_2.uid, - "global_key": gk_2 - }] + assignment_inputs = [ + {"data_row_id": dr_1.uid, "global_key": gk_1}, + {"data_row_id": dr_2.uid, "global_key": gk_2}, + ] res = client.assign_global_keys_to_data_rows(assignment_inputs) - assert len(res['results']) == 1 - assert len(res['errors']) == 1 - assert res['status'] == 'PARTIAL SUCCESS' - assert res['results'][0]['data_row_id'] == dr_1.uid - assert res['results'][0]['global_key'] == gk_1 - assert res['errors'][0]['data_row_id'] == dr_2.uid - assert res['errors'][0]['global_key'] == gk_2 - assert res['errors'][0][ - 'error'] == 'Invalid assignment. Either DataRow does not exist, or globalKey is invalid' + assert len(res["results"]) == 1 + assert len(res["errors"]) == 1 + assert res["status"] == "PARTIAL SUCCESS" + assert res["results"][0]["data_row_id"] == dr_1.uid + assert res["results"][0]["global_key"] == gk_1 + assert res["errors"][0]["data_row_id"] == dr_2.uid + assert res["errors"][0]["global_key"] == gk_2 + assert ( + res["errors"][0]["error"] + == "Invalid assignment. 
Either DataRow does not exist, or globalKey is invalid" + ) def test_global_key_with_whitespaces_validation(client, dataset, image_url): - data_row_items = [{ - "row_data": image_url, - }, { - "row_data": image_url, - }, { - "row_data": image_url, - }] + data_row_items = [ + { + "row_data": image_url, + }, + { + "row_data": image_url, + }, + { + "row_data": image_url, + }, + ] task = dataset.create_data_rows(data_row_items) task.wait_till_done() assert task.status == "COMPLETE" - dr_1_uid, dr_2_uid, dr_3_uid = [t['id'] for t in task.result] - - gk_1 = ' global key' - gk_2 = 'global key' - gk_3 = 'global key ' - - assignment_inputs = [{ - "data_row_id": dr_1_uid, - "global_key": gk_1 - }, { - "data_row_id": dr_2_uid, - "global_key": gk_2 - }, { - "data_row_id": dr_3_uid, - "global_key": gk_3 - }] + dr_1_uid, dr_2_uid, dr_3_uid = [t["id"] for t in task.result] + + gk_1 = " global key" + gk_2 = "global key" + gk_3 = "global key " + + assignment_inputs = [ + {"data_row_id": dr_1_uid, "global_key": gk_1}, + {"data_row_id": dr_2_uid, "global_key": gk_2}, + {"data_row_id": dr_3_uid, "global_key": gk_3}, + ] res = client.assign_global_keys_to_data_rows(assignment_inputs) - assert len(res['results']) == 0 - assert len(res['errors']) == 3 - assert res['status'] == 'FAILURE' - assign_errors_ids = set([e['data_row_id'] for e in res['errors']]) - assign_errors_gks = set([e['global_key'] for e in res['errors']]) - assign_errors_msgs = set([e['error'] for e in res['errors']]) + assert len(res["results"]) == 0 + assert len(res["errors"]) == 3 + assert res["status"] == "FAILURE" + assign_errors_ids = set([e["data_row_id"] for e in res["errors"]]) + assign_errors_gks = set([e["global_key"] for e in res["errors"]]) + assign_errors_msgs = set([e["error"] for e in res["errors"]]) assert assign_errors_ids == set([dr_1_uid, dr_2_uid, dr_3_uid]) assert assign_errors_gks == set([gk_1, gk_2, gk_3]) - assert assign_errors_msgs == set([ - 'Invalid assignment. Either DataRow does not exist, or globalKey is invalid', - 'Invalid assignment. Either DataRow does not exist, or globalKey is invalid', - 'Invalid assignment. Either DataRow does not exist, or globalKey is invalid' - ]) + assert assign_errors_msgs == set( + [ + "Invalid assignment. Either DataRow does not exist, or globalKey is invalid", + "Invalid assignment. Either DataRow does not exist, or globalKey is invalid", + "Invalid assignment. 
Either DataRow does not exist, or globalKey is invalid", + ] + ) def test_get_data_row_ids_for_global_keys(client, dataset, image_url): gk_1 = str(uuid.uuid4()) gk_2 = str(uuid.uuid4()) - dr_1 = dataset.create_data_row(row_data=image_url, - external_id="hello", - global_key=gk_1) - dr_2 = dataset.create_data_row(row_data=image_url, - external_id="world", - global_key=gk_2) + dr_1 = dataset.create_data_row( + row_data=image_url, external_id="hello", global_key=gk_1 + ) + dr_2 = dataset.create_data_row( + row_data=image_url, external_id="world", global_key=gk_2 + ) res = client.get_data_row_ids_for_global_keys([gk_1]) - assert res['status'] == "SUCCESS" - assert res['errors'] == [] - assert res['results'] == [dr_1.uid] + assert res["status"] == "SUCCESS" + assert res["errors"] == [] + assert res["results"] == [dr_1.uid] res = client.get_data_row_ids_for_global_keys([gk_2]) - assert res['status'] == "SUCCESS" - assert res['errors'] == [] - assert res['results'] == [dr_2.uid] + assert res["status"] == "SUCCESS" + assert res["errors"] == [] + assert res["results"] == [dr_2.uid] res = client.get_data_row_ids_for_global_keys([gk_1, gk_2]) - assert res['status'] == "SUCCESS" - assert res['errors'] == [] - assert res['results'] == [dr_1.uid, dr_2.uid] + assert res["status"] == "SUCCESS" + assert res["errors"] == [] + assert res["results"] == [dr_1.uid, dr_2.uid] def test_get_data_row_ids_for_invalid_global_keys(client, dataset, image_url): @@ -183,24 +173,24 @@ def test_get_data_row_ids_for_invalid_global_keys(client, dataset, image_url): gk_2 = str(uuid.uuid4()) dr_1 = dataset.create_data_row(row_data=image_url, external_id="hello") - dr_2 = dataset.create_data_row(row_data=image_url, - external_id="world", - global_key=gk_2) + dr_2 = dataset.create_data_row( + row_data=image_url, external_id="world", global_key=gk_2 + ) res = client.get_data_row_ids_for_global_keys([gk_1]) - assert res['status'] == "FAILURE" - assert len(res['errors']) == 1 - assert res['errors'][0]['error'] == "Data Row not found" - assert res['errors'][0]['global_key'] == gk_1 + assert res["status"] == "FAILURE" + assert len(res["errors"]) == 1 + assert res["errors"][0]["error"] == "Data Row not found" + assert res["errors"][0]["global_key"] == gk_1 res = client.get_data_row_ids_for_global_keys([gk_1, gk_2]) - assert res['status'] == "PARTIAL SUCCESS" + assert res["status"] == "PARTIAL SUCCESS" - assert len(res['errors']) == 1 - assert len(res['results']) == 2 + assert len(res["errors"]) == 1 + assert len(res["results"]) == 2 - assert res['errors'][0]['error'] == "Data Row not found" - assert res['errors'][0]['global_key'] == gk_1 + assert res["errors"][0]["error"] == "Data Row not found" + assert res["errors"][0]["global_key"] == gk_1 - assert res['results'][0] == '' - assert res['results'][1] == dr_2.uid + assert res["results"][0] == "" + assert res["results"][1] == dr_2.uid diff --git a/libs/labelbox/tests/integration/test_label.py b/libs/labelbox/tests/integration/test_label.py index c7221553e..1bd8a8276 100644 --- a/libs/labelbox/tests/integration/test_label.py +++ b/libs/labelbox/tests/integration/test_label.py @@ -29,10 +29,10 @@ def test_labels(configured_project_with_label): # TODO: Skipping this test in staging due to label not updating @pytest.mark.skipif( - condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem" or - os.environ["LABELBOX_TEST_ENVIRON"] == "staging" or - os.environ["LABELBOX_TEST_ENVIRON"] == "local" or - os.environ["LABELBOX_TEST_ENVIRON"] == "custom", + 
condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem" + or os.environ["LABELBOX_TEST_ENVIRON"] == "staging" + or os.environ["LABELBOX_TEST_ENVIRON"] == "local" + or os.environ["LABELBOX_TEST_ENVIRON"] == "custom", reason="does not work for onprem", ) def test_label_update(configured_project_with_label): @@ -82,8 +82,10 @@ def test_upsert_label_scores(configured_project_with_label, client: Client): label = next(project.labels()) - scores = client.upsert_label_feedback(label_id=label.uid, - feedback="That's a great label!", - scores={"overall": 5}) + scores = client.upsert_label_feedback( + label_id=label.uid, + feedback="That's a great label!", + scores={"overall": 5}, + ) assert len(scores) == 1 assert scores[0].score == 5 diff --git a/libs/labelbox/tests/integration/test_labeling_dashboard.py b/libs/labelbox/tests/integration/test_labeling_dashboard.py index 96d6af57f..97536e337 100644 --- a/libs/labelbox/tests/integration/test_labeling_dashboard.py +++ b/libs/labelbox/tests/integration/test_labeling_dashboard.py @@ -1,5 +1,21 @@ from datetime import datetime, timedelta -from labelbox.schema.search_filters import IntegerValue, RangeDateTimeOperatorWithSingleValue, RangeOperatorWithSingleValue, DateRange, RangeOperatorWithValue, DateRangeValue, DateValue, IdOperator, OperationType, OrganizationFilter, TaskCompletedCountFilter, WorkforceRequestedDateFilter, WorkforceRequestedDateRangeFilter, WorkspaceFilter, TaskRemainingCountFilter +from labelbox.schema.search_filters import ( + IntegerValue, + RangeDateTimeOperatorWithSingleValue, + RangeOperatorWithSingleValue, + DateRange, + RangeOperatorWithValue, + DateRangeValue, + DateValue, + IdOperator, + OperationType, + OrganizationFilter, + TaskCompletedCountFilter, + WorkforceRequestedDateFilter, + WorkforceRequestedDateRangeFilter, + WorkspaceFilter, + TaskRemainingCountFilter, +) def test_request_labeling_service_dashboard(requested_labeling_service): @@ -20,12 +36,14 @@ def test_request_labeling_service_dashboard_filters(requested_labeling_service): project, _ = requested_labeling_service organization = project.client.get_organization() - org_filter = OrganizationFilter(operator=IdOperator.Is, - values=[organization.uid]) + org_filter = OrganizationFilter( + operator=IdOperator.Is, values=[organization.uid] + ) try: project.client.get_labeling_service_dashboards( - search_query=[org_filter]).get_one() + search_query=[org_filter] + ).get_one() except Exception as e: assert False, f"An exception was raised: {e}" @@ -33,41 +51,55 @@ def test_request_labeling_service_dashboard_filters(requested_labeling_service): operation=OperationType.WorforceRequestedDate, value=DateValue( operator=RangeDateTimeOperatorWithSingleValue.GreaterThanOrEqual, - value=datetime.strptime("2024-01-01", "%Y-%m-%d"))) - year_from_now = (datetime.now() + timedelta(days=365)) + value=datetime.strptime("2024-01-01", "%Y-%m-%d"), + ), + ) + year_from_now = datetime.now() + timedelta(days=365) workforce_requested_filter_before = WorkforceRequestedDateFilter( operation=OperationType.WorforceRequestedDate, value=DateValue( operator=RangeDateTimeOperatorWithSingleValue.LessThanOrEqual, - value=year_from_now)) + value=year_from_now, + ), + ) try: - project.client.get_labeling_service_dashboards(search_query=[ - workforce_requested_filter_after, workforce_requested_filter_before - ]).get_one() + project.client.get_labeling_service_dashboards( + search_query=[ + workforce_requested_filter_after, + workforce_requested_filter_before, + ] + ).get_one() except Exception as e: 
assert False, f"An exception was raised: {e}" workforce_date_range_filter = WorkforceRequestedDateRangeFilter( operation=OperationType.WorforceRequestedDate, - value=DateRangeValue(operator=RangeOperatorWithValue.Between, - value=DateRange(min="2024-01-01T00:00:00-0800", - max=year_from_now))) + value=DateRangeValue( + operator=RangeOperatorWithValue.Between, + value=DateRange(min="2024-01-01T00:00:00-0800", max=year_from_now), + ), + ) try: project.client.get_labeling_service_dashboards( - search_query=[workforce_date_range_filter]).get_one() + search_query=[workforce_date_range_filter] + ).get_one() except Exception as e: assert False, f"An exception was raised: {e}" # with non existing data workspace_id = "clzzu4rme000008l42vnl4kre" - workspace_filter = WorkspaceFilter(operation=OperationType.Workspace, - operator=IdOperator.Is, - values=[workspace_id]) + workspace_filter = WorkspaceFilter( + operation=OperationType.Workspace, + operator=IdOperator.Is, + values=[workspace_id], + ) labeling_service_dashboard = [ - ld for ld in project.client.get_labeling_service_dashboards( - search_query=[workspace_filter]) + ld + for ld in project.client.get_labeling_service_dashboards( + search_query=[workspace_filter] + ) ] assert len(labeling_service_dashboard) == 0 assert labeling_service_dashboard == [] @@ -75,15 +107,19 @@ def test_request_labeling_service_dashboard_filters(requested_labeling_service): task_done_count_filter = TaskCompletedCountFilter( operation=OperationType.TaskCompletedCount, value=IntegerValue( - operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, value=0)) + operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, value=0 + ), + ) task_remaining_count_filter = TaskRemainingCountFilter( operation=OperationType.TaskRemainingCount, value=IntegerValue( - operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, value=0)) + operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, value=0 + ), + ) try: project.client.get_labeling_service_dashboards( - search_query=[task_done_count_filter, task_remaining_count_filter - ]).get_one() + search_query=[task_done_count_filter, task_remaining_count_filter] + ).get_one() except Exception as e: assert False, f"An exception was raised: {e}" diff --git a/libs/labelbox/tests/integration/test_labeling_frontend.py b/libs/labelbox/tests/integration/test_labeling_frontend.py index d13871372..d6ea1aac9 100644 --- a/libs/labelbox/tests/integration/test_labeling_frontend.py +++ b/libs/labelbox/tests/integration/test_labeling_frontend.py @@ -6,14 +6,16 @@ def test_get_labeling_frontends(client): filtered_frontends = list( - client.get_labeling_frontends(where=LabelingFrontend.name == 'Editor')) + client.get_labeling_frontends(where=LabelingFrontend.name == "Editor") + ) assert len(filtered_frontends) def test_labeling_frontend_connecting_to_project(project): client = project.client default_labeling_frontend = next( - client.get_labeling_frontends(where=LabelingFrontend.name == "Editor")) + client.get_labeling_frontends(where=LabelingFrontend.name == "Editor") + ) assert project.labeling_frontend() is None diff --git a/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py b/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py index 51c56353c..bd14040de 100644 --- a/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py +++ b/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py @@ -8,13 +8,20 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): [project, _, 
data_rows] = consensus_project_with_batch init_labeling_parameter_overrides = list( - project.labeling_parameter_overrides()) + project.labeling_parameter_overrides() + ) assert len(init_labeling_parameter_overrides) == 3 - assert {o.number_of_labels for o in init_labeling_parameter_overrides - } == {1, 1, 1} + assert {o.number_of_labels for o in init_labeling_parameter_overrides} == { + 1, + 1, + 1, + } assert {o.priority for o in init_labeling_parameter_overrides} == {5, 5, 5} - assert {o.data_row().uid for o in init_labeling_parameter_overrides - } == {data_rows[0].uid, data_rows[1].uid, data_rows[2].uid} + assert {o.data_row().uid for o in init_labeling_parameter_overrides} == { + data_rows[0].uid, + data_rows[1].uid, + data_rows[2].uid, + } data = [(data_rows[0], 4, 2), (data_rows[1], 3)] success = project.set_labeling_parameter_overrides(data) @@ -28,8 +35,11 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): for override in updated_overrides: assert isinstance(override.data_row(), DataRow) - data = [(UniqueId(data_rows[0].uid), 1, 2), (UniqueId(data_rows[1].uid), 2), - (UniqueId(data_rows[2].uid), 3)] + data = [ + (UniqueId(data_rows[0].uid), 1, 2), + (UniqueId(data_rows[1].uid), 2), + (UniqueId(data_rows[2].uid), 3), + ] success = project.set_labeling_parameter_overrides(data) assert success updated_overrides = list(project.labeling_parameter_overrides()) @@ -37,9 +47,11 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1} assert {o.priority for o in updated_overrides} == {1, 2, 3} - data = [(GlobalKey(data_rows[0].global_key), 2, 2), - (GlobalKey(data_rows[1].global_key), 3, 3), - (GlobalKey(data_rows[2].global_key), 4)] + data = [ + (GlobalKey(data_rows[0].global_key), 2, 2), + (GlobalKey(data_rows[1].global_key), 3, 3), + (GlobalKey(data_rows[2].global_key), 4), + ] success = project.set_labeling_parameter_overrides(data) assert success updated_overrides = list(project.labeling_parameter_overrides()) @@ -50,21 +62,26 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): with pytest.raises(TypeError) as exc_info: data = [(data_rows[2], "a_string", 3)] project.set_labeling_parameter_overrides(data) - assert str(exc_info.value) == \ - f"Priority must be an int. Found for data_row_identifier {data_rows[2].uid}" + assert ( + str(exc_info.value) + == f"Priority must be an int. Found for data_row_identifier {data_rows[2].uid}" + ) with pytest.raises(TypeError) as exc_info: data = [(data_rows[2].uid, 1)] project.set_labeling_parameter_overrides(data) - assert str(exc_info.value) == \ - f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found for data_row_identifier {data_rows[2].uid}" + assert ( + str(exc_info.value) + == f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. 
Found for data_row_identifier {data_rows[2].uid}" + ) def test_set_labeling_priority(consensus_project_with_batch): [project, _, data_rows] = consensus_project_with_batch init_labeling_parameter_overrides = list( - project.labeling_parameter_overrides()) + project.labeling_parameter_overrides() + ) assert len(init_labeling_parameter_overrides) == 3 assert {o.priority for o in init_labeling_parameter_overrides} == {5, 5, 5} diff --git a/libs/labelbox/tests/integration/test_labeling_service.py b/libs/labelbox/tests/integration/test_labeling_service.py index be0b8a6ee..09b5c24a1 100644 --- a/libs/labelbox/tests/integration/test_labeling_service.py +++ b/libs/labelbox/tests/integration/test_labeling_service.py @@ -15,8 +15,12 @@ def test_start_labeling_service(project): def test_request_labeling_service_moe_offline_project( - rand_gen, offline_chat_evaluation_project, chat_evaluation_ontology, - offline_conversational_data_row, model_config): + rand_gen, + offline_chat_evaluation_project, + chat_evaluation_ontology, + offline_conversational_data_row, + model_config, +): project = offline_chat_evaluation_project project.connect_ontology(chat_evaluation_ontology) @@ -25,43 +29,48 @@ def test_request_labeling_service_moe_offline_project( [offline_conversational_data_row.uid], # sample of data row objects ) - project.upsert_instructions('tests/integration/media/sample_pdf.pdf') + project.upsert_instructions("tests/integration/media/sample_pdf.pdf") labeling_service = project.get_labeling_service() labeling_service.request() - assert project.get_labeling_service_status( - ) == LabelingServiceStatus.Requested + assert ( + project.get_labeling_service_status() == LabelingServiceStatus.Requested + ) def test_request_labeling_service_moe_project( - rand_gen, live_chat_evaluation_project_with_new_dataset, - chat_evaluation_ontology, model_config): + rand_gen, + live_chat_evaluation_project_with_new_dataset, + chat_evaluation_ontology, + model_config, +): project = live_chat_evaluation_project_with_new_dataset project.connect_ontology(chat_evaluation_ontology) - project.upsert_instructions('tests/integration/media/sample_pdf.pdf') + project.upsert_instructions("tests/integration/media/sample_pdf.pdf") labeling_service = project.get_labeling_service() with pytest.raises( - LabelboxError, - match= - '[{"errorType":"PROJECT_MODEL_CONFIG","errorMessage":"Project model config is not completed"}]' + LabelboxError, + match='[{"errorType":"PROJECT_MODEL_CONFIG","errorMessage":"Project model config is not completed"}]', ): labeling_service.request() project.add_model_config(model_config.uid) project.set_project_model_setup_complete() labeling_service.request() - assert project.get_labeling_service_status( - ) == LabelingServiceStatus.Requested + assert ( + project.get_labeling_service_status() == LabelingServiceStatus.Requested + ) def test_request_labeling_service_incomplete_requirements(ontology, project): - labeling_service = project.get_labeling_service( + labeling_service = ( + project.get_labeling_service() ) # project fixture is an Image type project - with pytest.raises(ResourceNotFoundError, - match="Associated ontology id could not be found" - ): # No labeling service by default + with pytest.raises( + ResourceNotFoundError, match="Associated ontology id could not be found" + ): # No labeling service by default labeling_service.request() project.connect_ontology(ontology) with pytest.raises(LabelboxError): diff --git a/libs/labelbox/tests/integration/test_legacy_project.py 
b/libs/labelbox/tests/integration/test_legacy_project.py index fbdf8b252..320a2191d 100644 --- a/libs/labelbox/tests/integration/test_legacy_project.py +++ b/libs/labelbox/tests/integration/test_legacy_project.py @@ -5,9 +5,8 @@ def test_project_dataset(client, rand_gen): with pytest.raises( - ValueError, - match= - "Dataset queue mode is deprecated. Please prefer Batch queue mode." + ValueError, + match="Dataset queue mode is deprecated. Please prefer Batch queue mode.", ): client.create_project( name=rand_gen(str), @@ -30,10 +29,12 @@ def test_project_auto_audit_parameters(client, rand_gen): def test_project_name_parameter(client, rand_gen): - with pytest.raises(ValueError, - match="project name must be a valid string."): + with pytest.raises( + ValueError, match="project name must be a valid string." + ): client.create_project() - with pytest.raises(ValueError, - match="project name must be a valid string."): + with pytest.raises( + ValueError, match="project name must be a valid string." + ): client.create_project(name=" ") diff --git a/libs/labelbox/tests/integration/test_model_config.py b/libs/labelbox/tests/integration/test_model_config.py index 960b096c6..7a060b917 100644 --- a/libs/labelbox/tests/integration/test_model_config.py +++ b/libs/labelbox/tests/integration/test_model_config.py @@ -1,16 +1,22 @@ import pytest from labelbox.exceptions import ResourceNotFoundError + def test_create_model_config(client, valid_model_id): - model_config = client.create_model_config("model_config", valid_model_id, {"param": "value"}) + model_config = client.create_model_config( + "model_config", valid_model_id, {"param": "value"} + ) assert model_config.inference_params["param"] == "value" assert model_config.name == "model_config" assert model_config.model_id == valid_model_id def test_delete_model_config(client, valid_model_id): - model_config_id = client.create_model_config("model_config", valid_model_id, {"param": "value"}) - assert(client.delete_model_config(model_config_id.uid)) + model_config_id = client.create_model_config( + "model_config", valid_model_id, {"param": "value"} + ) + assert client.delete_model_config(model_config_id.uid) + def test_delete_nonexistant_model_config(client): with pytest.raises(ResourceNotFoundError): diff --git a/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py b/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py index 2ff5607c3..bb1756afb 100644 --- a/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py +++ b/libs/labelbox/tests/integration/test_offline_chat_evaluation_project.py @@ -1,11 +1,14 @@ import pytest -def test_create_offline_chat_evaluation_project(client, rand_gen, - offline_chat_evaluation_project, - chat_evaluation_ontology, - offline_conversational_data_row, - model_config): +def test_create_offline_chat_evaluation_project( + client, + rand_gen, + offline_chat_evaluation_project, + chat_evaluation_ontology, + offline_conversational_data_row, + model_config, +): project = offline_chat_evaluation_project assert project diff --git a/libs/labelbox/tests/integration/test_ontology.py b/libs/labelbox/tests/integration/test_ontology.py index 0b6b23e73..91ef74a39 100644 --- a/libs/labelbox/tests/integration/test_ontology.py +++ b/libs/labelbox/tests/integration/test_ontology.py @@ -9,57 +9,67 @@ def test_feature_schema_is_not_archived(client, ontology): - feature_schema_to_check = ontology.normalized['tools'][0] + feature_schema_to_check = ontology.normalized["tools"][0] result = 
client.is_feature_schema_archived( - ontology.uid, feature_schema_to_check['featureSchemaId']) + ontology.uid, feature_schema_to_check["featureSchemaId"] + ) assert result == False def test_feature_schema_is_archived(client, configured_project_with_label): project, _, _, label = configured_project_with_label ontology = project.ontology() - feature_schema_id = ontology.normalized['tools'][0]['featureSchemaId'] - result = client.delete_feature_schema_from_ontology(ontology.uid, - feature_schema_id) + feature_schema_id = ontology.normalized["tools"][0]["featureSchemaId"] + result = client.delete_feature_schema_from_ontology( + ontology.uid, feature_schema_id + ) assert result.archived == True and result.deleted == False - assert client.is_feature_schema_archived(ontology.uid, - feature_schema_id) == True + assert ( + client.is_feature_schema_archived(ontology.uid, feature_schema_id) + == True + ) def test_is_feature_schema_archived_for_non_existing_feature_schema( - client, ontology): + client, ontology +): with pytest.raises( - Exception, - match="The specified feature schema was not in the ontology"): - client.is_feature_schema_archived(ontology.uid, - 'invalid-feature-schema-id') + Exception, match="The specified feature schema was not in the ontology" + ): + client.is_feature_schema_archived( + ontology.uid, "invalid-feature-schema-id" + ) def test_is_feature_schema_archived_for_non_existing_ontology(client, ontology): - feature_schema_to_unarchive = ontology.normalized['tools'][0] + feature_schema_to_unarchive = ontology.normalized["tools"][0] with pytest.raises( - Exception, - match="Resource 'Ontology' not found for params: 'invalid-ontology'" + Exception, + match="Resource 'Ontology' not found for params: 'invalid-ontology'", ): client.is_feature_schema_archived( - 'invalid-ontology', feature_schema_to_unarchive['featureSchemaId']) + "invalid-ontology", feature_schema_to_unarchive["featureSchemaId"] + ) def test_delete_tool_feature_from_ontology(client, ontology): - feature_schema_to_delete = ontology.normalized['tools'][0] - assert len(ontology.normalized['tools']) == 2 + feature_schema_to_delete = ontology.normalized["tools"][0] + assert len(ontology.normalized["tools"]) == 2 result = client.delete_feature_schema_from_ontology( - ontology.uid, feature_schema_to_delete['featureSchemaId']) + ontology.uid, feature_schema_to_delete["featureSchemaId"] + ) assert result.deleted == True assert result.archived == False updatedOntology = client.get_ontology(ontology.uid) - assert len(updatedOntology.normalized['tools']) == 1 + assert len(updatedOntology.normalized["tools"]) == 1 -@pytest.mark.skip(reason="normalized ontology contains Relationship, " - "which is not finalized yet. introduce this back when" - "Relationship feature is complete and we introduce" - "a Relationship object to the ontology that we can parse") +@pytest.mark.skip( + reason="normalized ontology contains Relationship, " + "which is not finalized yet. 
introduce this back when" + "Relationship feature is complete and we introduce" + "a Relationship object to the ontology that we can parse" +) def test_from_project_ontology(project) -> None: o = OntologyBuilder.from_project(project) assert o.asdict() == project.ontology().normalized @@ -74,11 +84,12 @@ def test_from_project_ontology(project) -> None: def test_deletes_an_ontology(client): tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] ontology = client.create_ontology_from_feature_schemas( - name='ontology name', + name="ontology name", feature_schema_ids=[feature_schema_id], - media_type=MediaType.Image) + media_type=MediaType.Image, + ) assert client.delete_unused_ontology(ontology.uid) is None @@ -86,22 +97,25 @@ def test_deletes_an_ontology(client): def test_cant_delete_an_ontology_with_project(client): - project = client.create_project(name="test project", - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) + project = client.create_project( + name="test project", + queue_mode=QueueMode.Batch, + media_type=MediaType.Image, + ) tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] ontology = client.create_ontology_from_feature_schemas( - name='ontology name', + name="ontology name", feature_schema_ids=[feature_schema_id], - media_type=MediaType.Image) + media_type=MediaType.Image, + ) project.connect_ontology(ontology) with pytest.raises( - Exception, - match= - "Failed to delete the ontology, message: Cannot delete an ontology connected to a project. The ontology is connected to projects: " - + project.uid): + Exception, + match="Failed to delete the ontology, message: Cannot delete an ontology connected to a project. 
The ontology is connected to projects: " + + project.uid, + ): client.delete_unused_ontology(ontology.uid) project.delete() @@ -110,56 +124,72 @@ def test_cant_delete_an_ontology_with_project(client): def test_inserts_a_feature_schema_at_given_position(client): - tool1 = {'tool': 'polygon', 'name': 'tool1', 'color': 'blue'} - tool2 = {'tool': 'polygon', 'name': 'tool2', 'color': 'blue'} + tool1 = {"tool": "polygon", "name": "tool1", "color": "blue"} + tool2 = {"tool": "polygon", "name": "tool2", "color": "blue"} ontology_normalized_json = {"tools": [tool1, tool2], "classifications": []} - ontology = client.create_ontology(name="ontology", - normalized=ontology_normalized_json, - media_type=MediaType.Image) + ontology = client.create_ontology( + name="ontology", + normalized=ontology_normalized_json, + media_type=MediaType.Image, + ) created_feature_schema = client.upsert_feature_schema(point.asdict()) client.insert_feature_schema_into_ontology( - created_feature_schema.normalized['featureSchemaId'], ontology.uid, 1) + created_feature_schema.normalized["featureSchemaId"], ontology.uid, 1 + ) ontology = client.get_ontology(ontology.uid) - assert ontology.normalized['tools'][1][ - 'schemaNodeId'] == created_feature_schema.normalized['schemaNodeId'] + assert ( + ontology.normalized["tools"][1]["schemaNodeId"] + == created_feature_schema.normalized["schemaNodeId"] + ) client.delete_unused_ontology(ontology.uid) def test_moves_already_added_feature_schema_in_ontology(client): - tool1 = {'tool': 'polygon', 'name': 'tool1', 'color': 'blue'} + tool1 = {"tool": "polygon", "name": "tool1", "color": "blue"} ontology_normalized_json = {"tools": [tool1], "classifications": []} - ontology = client.create_ontology(name="ontology", - normalized=ontology_normalized_json, - media_type=MediaType.Image) + ontology = client.create_ontology( + name="ontology", + normalized=ontology_normalized_json, + media_type=MediaType.Image, + ) created_feature_schema = client.upsert_feature_schema(point.asdict()) - feature_schema_id = created_feature_schema.normalized['featureSchemaId'] - client.insert_feature_schema_into_ontology(feature_schema_id, ontology.uid, - 1) + feature_schema_id = created_feature_schema.normalized["featureSchemaId"] + client.insert_feature_schema_into_ontology( + feature_schema_id, ontology.uid, 1 + ) ontology = client.get_ontology(ontology.uid) - assert ontology.normalized['tools'][1][ - 'schemaNodeId'] == created_feature_schema.normalized['schemaNodeId'] - client.insert_feature_schema_into_ontology(feature_schema_id, ontology.uid, - 0) + assert ( + ontology.normalized["tools"][1]["schemaNodeId"] + == created_feature_schema.normalized["schemaNodeId"] + ) + client.insert_feature_schema_into_ontology( + feature_schema_id, ontology.uid, 0 + ) ontology = client.get_ontology(ontology.uid) - assert ontology.normalized['tools'][0][ - 'schemaNodeId'] == created_feature_schema.normalized['schemaNodeId'] + assert ( + ontology.normalized["tools"][0]["schemaNodeId"] + == created_feature_schema.normalized["schemaNodeId"] + ) client.delete_unused_ontology(ontology.uid) def test_does_not_include_used_ontologies(client): tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] + feature_schema_id = tool.normalized["featureSchemaId"] ontology_with_project = client.create_ontology_from_feature_schemas( - name='ontology name', + name="ontology name", feature_schema_ids=[feature_schema_id], - media_type=MediaType.Image) - project = client.create_project(name="test 
project", - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) + media_type=MediaType.Image, + ) + project = client.create_project( + name="test project", + queue_mode=QueueMode.Batch, + media_type=MediaType.Image, + ) project.connect_ontology(ontology_with_project) unused_ontologies = client.get_unused_ontologies() @@ -185,10 +215,10 @@ def name_for_read(rand_gen): @pytest.fixture def feature_schema_cat_normalized(name_for_read): yield { - 'tool': 'polygon', - 'name': name_for_read, - 'color': 'black', - 'classifications': [], + "tool": "polygon", + "name": name_for_read, + "color": "black", + "classifications": [], } @@ -199,26 +229,29 @@ def feature_schema_for_read(client, feature_schema_cat_normalized): client.delete_unused_feature_schema(feature_schema.uid) -def test_feature_schema_create_read(client, feature_schema_for_read, - name_for_read): +def test_feature_schema_create_read( + client, feature_schema_for_read, name_for_read +): created_feature_schema = feature_schema_for_read queried_feature_schema = client.get_feature_schema( - created_feature_schema.uid) + created_feature_schema.uid + ) for attr in Entity.FeatureSchema.fields(): - assert _get_attr_stringify_json(created_feature_schema, - attr) == _get_attr_stringify_json( - queried_feature_schema, attr) + assert _get_attr_stringify_json( + created_feature_schema, attr + ) == _get_attr_stringify_json(queried_feature_schema, attr) time.sleep(3) # Slight delay for searching queried_feature_schemas = list(client.get_feature_schemas(name_for_read)) - assert [feature_schema.name for feature_schema in queried_feature_schemas - ] == [name_for_read] + assert [ + feature_schema.name for feature_schema in queried_feature_schemas + ] == [name_for_read] queried_feature_schema = queried_feature_schemas[0] for attr in Entity.FeatureSchema.fields(): - assert _get_attr_stringify_json(created_feature_schema, - attr) == _get_attr_stringify_json( - queried_feature_schema, attr) + assert _get_attr_stringify_json( + created_feature_schema, attr + ) == _get_attr_stringify_json(queried_feature_schema, attr) def test_ontology_create_read( @@ -228,61 +261,67 @@ def test_ontology_create_read( ontology_name = f"test-ontology-{rand_gen(str)}" tool_name = f"test-ontology-tool-{rand_gen(str)}" feature_schema_cat_normalized = { - 'tool': 'polygon', - 'name': tool_name, - 'color': 'black', - 'classifications': [], + "tool": "polygon", + "name": tool_name, + "color": "black", + "classifications": [], } feature_schema = client.create_feature_schema(feature_schema_cat_normalized) created_ontology = client.create_ontology_from_feature_schemas( name=ontology_name, feature_schema_ids=[feature_schema.uid], - media_type=MediaType.Image) - tool_normalized = created_ontology.normalized['tools'][0] + media_type=MediaType.Image, + ) + tool_normalized = created_ontology.normalized["tools"][0] for k, v in feature_schema_cat_normalized.items(): assert tool_normalized[k] == v - assert tool_normalized['schemaNodeId'] is not None - assert tool_normalized['featureSchemaId'] == feature_schema.uid + assert tool_normalized["schemaNodeId"] is not None + assert tool_normalized["featureSchemaId"] == feature_schema.uid queried_ontology = client.get_ontology(created_ontology.uid) for attr in Entity.Ontology.fields(): - assert _get_attr_stringify_json(created_ontology, - attr) == _get_attr_stringify_json( - queried_ontology, attr) + assert _get_attr_stringify_json( + created_ontology, attr + ) == _get_attr_stringify_json(queried_ontology, attr) time.sleep(3) # Slight delay for 
searching queried_ontologies = list(client.get_ontologies(ontology_name)) assert [ontology.name for ontology in queried_ontologies] == [ontology_name] queried_ontology = queried_ontologies[0] for attr in Entity.Ontology.fields(): - assert _get_attr_stringify_json(created_ontology, - attr) == _get_attr_stringify_json( - queried_ontology, attr) + assert _get_attr_stringify_json( + created_ontology, attr + ) == _get_attr_stringify_json(queried_ontology, attr) def test_unarchive_feature_schema_node(client, ontology): - feature_schema_to_unarchive = ontology.normalized['tools'][0] + feature_schema_to_unarchive = ontology.normalized["tools"][0] result = client.unarchive_feature_schema_node( - ontology.uid, feature_schema_to_unarchive['featureSchemaId']) + ontology.uid, feature_schema_to_unarchive["featureSchemaId"] + ) assert result == None def test_unarchive_feature_schema_node_for_non_existing_feature_schema( - client, ontology): + client, ontology +): with pytest.raises( - Exception, - match= - "Failed to find feature schema node by id: invalid-feature-schema-id" + Exception, + match="Failed to find feature schema node by id: invalid-feature-schema-id", ): - client.unarchive_feature_schema_node(ontology.uid, - 'invalid-feature-schema-id') + client.unarchive_feature_schema_node( + ontology.uid, "invalid-feature-schema-id" + ) def test_unarchive_feature_schema_node_for_non_existing_ontology( - client, ontology): - feature_schema_to_unarchive = ontology.normalized['tools'][0] - with pytest.raises(Exception, - match="Failed to find ontology by id: invalid-ontology"): + client, ontology +): + feature_schema_to_unarchive = ontology.normalized["tools"][0] + with pytest.raises( + Exception, match="Failed to find ontology by id: invalid-ontology" + ): client.unarchive_feature_schema_node( - 'invalid-ontology', feature_schema_to_unarchive['featureSchemaId']) + "invalid-ontology", feature_schema_to_unarchive["featureSchemaId"] + ) diff --git a/libs/labelbox/tests/integration/test_project.py b/libs/labelbox/tests/integration/test_project.py index 7b63ee391..a38fa2b5d 100644 --- a/libs/labelbox/tests/integration/test_project.py +++ b/libs/labelbox/tests/integration/test_project.py @@ -71,7 +71,9 @@ def delete_tag(tag_id: str): id } } - """, {"tag_id": tag_id}) + """, + {"tag_id": tag_id}, + ) return res org = client.get_organization() @@ -89,7 +91,7 @@ def delete_tag(tag_id: str): tagA = client.get_organization().create_resource_tag(tag) assert tagA.text == textA - assert '#' + tagA.color == colorA + assert "#" + tagA.color == colorA assert tagA.uid is not None tags = org.get_resource_tags() @@ -98,7 +100,7 @@ def delete_tag(tag_id: str): tagB = client.get_organization().create_resource_tag(tagB) assert tagB.text == textB - assert '#' + tagB.color == colorB + assert "#" + tagB.color == colorB assert tagB.uid is not None tags = client.get_organization().get_resource_tags() @@ -107,7 +109,8 @@ def delete_tag(tag_id: str): assert lenB > lenA project_resource_tag = client.get_project( - p1.uid).update_project_resource_tags([str(tagA.uid)]) + p1.uid + ).update_project_resource_tags([str(tagA.uid)]) assert len(project_resource_tag) == 1 assert project_resource_tag[0].uid == tagA.uid @@ -136,75 +139,84 @@ def test_extend_reservations(project): project.extend_reservations("InvalidQueueType") -@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", - reason="new mutation does not work for onprem") +@pytest.mark.skipif( + condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem", + 
reason="new mutation does not work for onprem", +) def test_attach_instructions(client, project): with pytest.raises(ValueError) as execinfo: - project.upsert_instructions('tests/integration/media/sample_pdf.pdf') - assert str( - execinfo.value - ) == "Cannot attach instructions to a project that has not been set up." + project.upsert_instructions("tests/integration/media/sample_pdf.pdf") + assert ( + str(execinfo.value) + == "Cannot attach instructions to a project that has not been set up." + ) editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] + client.get_labeling_frontends(where=LabelingFrontend.name == "editor") + )[0] empty_ontology = {"tools": [], "classifications": []} project.setup(editor, empty_ontology) - project.upsert_instructions('tests/integration/media/sample_pdf.pdf') + project.upsert_instructions("tests/integration/media/sample_pdf.pdf") time.sleep(3) - assert project.ontology().normalized['projectInstructions'] is not None + assert project.ontology().normalized["projectInstructions"] is not None with pytest.raises(ValueError) as exc_info: - project.upsert_instructions('/tmp/file.invalid_file_extension') + project.upsert_instructions("/tmp/file.invalid_file_extension") assert "instructions_file must be a pdf or html file. Found" in str( - exc_info.value) + exc_info.value + ) -@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", - reason="new mutation does not work for onprem") +@pytest.mark.skipif( + condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem", + reason="new mutation does not work for onprem", +) def test_html_instructions(project_with_empty_ontology): - html_file_path = '/tmp/instructions.html' + html_file_path = "/tmp/instructions.html" sample_html_str = "" - with open(html_file_path, 'w') as file: + with open(html_file_path, "w") as file: file.write(sample_html_str) project_with_empty_ontology.upsert_instructions(html_file_path) updated_ontology = project_with_empty_ontology.ontology().normalized - instructions = updated_ontology.pop('projectInstructions') + instructions = updated_ontology.pop("projectInstructions") assert requests.get(instructions).text == sample_html_str -@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", - reason="new mutation does not work for onprem") +@pytest.mark.skipif( + condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem", + reason="new mutation does not work for onprem", +) def test_same_ontology_after_instructions( - configured_project_with_complex_ontology): + configured_project_with_complex_ontology, +): project, _ = configured_project_with_complex_ontology initial_ontology = project.ontology().normalized - project.upsert_instructions('tests/assets/loremipsum.pdf') + project.upsert_instructions("tests/assets/loremipsum.pdf") updated_ontology = project.ontology().normalized - instructions = updated_ontology.pop('projectInstructions') + instructions = updated_ontology.pop("projectInstructions") assert initial_ontology == updated_ontology assert instructions is not None def test_batches(project: Project, dataset: Dataset, image_url): - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image" - }, - ] * 2) + task = dataset.create_data_rows( + [ + {"row_data": image_url, "external_id": "my-image"}, + ] + * 2 + ) task.wait_till_done() export_task = dataset.export() export_task.wait_till_done() stream = export_task.get_buffered_stream() data_rows = [dr.json["data_row"]["id"] for dr in stream] - 
batch_one = f'batch one {uuid.uuid4()}' - batch_two = f'batch two {uuid.uuid4()}' + batch_one = f"batch one {uuid.uuid4()}" + batch_two = f"batch two {uuid.uuid4()}" project.create_batch(batch_one, [data_rows[0]]) project.create_batch(batch_two, [data_rows[1]]) @@ -212,19 +224,19 @@ def test_batches(project: Project, dataset: Dataset, image_url): assert names == {batch_one, batch_two} -@pytest.mark.parametrize('data_rows', [2], indirect=True) +@pytest.mark.parametrize("data_rows", [2], indirect=True) def test_create_batch_with_global_keys_sync(project: Project, data_rows): global_keys = [dr.global_key for dr in data_rows] - batch_name = f'batch {uuid.uuid4()}' + batch_name = f"batch {uuid.uuid4()}" batch = project.create_batch(batch_name, global_keys=global_keys) assert batch.size == len(set(data_rows)) -@pytest.mark.parametrize('data_rows', [2], indirect=True) +@pytest.mark.parametrize("data_rows", [2], indirect=True) def test_create_batch_with_global_keys_async(project: Project, data_rows): global_keys = [dr.global_key for dr in data_rows] - batch_name = f'batch {uuid.uuid4()}' + batch_name = f"batch {uuid.uuid4()}" batch = project._create_batch_async(batch_name, global_keys=global_keys) assert batch.size == len(set(data_rows)) @@ -243,28 +255,35 @@ def test_media_type(client, project: Project, rand_gen): for media_type in MediaType.get_supported_members(): # Exclude LLM media types for now, as they are not supported if MediaType[media_type] in [ - MediaType.LLMPromptCreation, - MediaType.LLMPromptResponseCreation, MediaType.LLM + MediaType.LLMPromptCreation, + MediaType.LLMPromptResponseCreation, + MediaType.LLM, ]: continue - project = client.create_project(name=rand_gen(str), - media_type=MediaType[media_type]) + project = client.create_project( + name=rand_gen(str), media_type=MediaType[media_type] + ) assert project.media_type == MediaType[media_type] project.delete() def test_queue_mode(client, rand_gen): - project = client.create_project(name=rand_gen(str)) # defaults to benchmark and consensus + project = client.create_project( + name=rand_gen(str) + ) # defaults to benchmark and consensus assert project.auto_audit_number_of_labels == 3 assert project.auto_audit_percentage == 0 - project = client.create_project(name=rand_gen(str), quality_modes=[QualityMode.Benchmark]) + project = client.create_project( + name=rand_gen(str), quality_modes=[QualityMode.Benchmark] + ) assert project.auto_audit_number_of_labels == 1 assert project.auto_audit_percentage == 1 project = client.create_project( - name=rand_gen(str), quality_modes=[QualityMode.Benchmark, QualityMode.Consensus] + name=rand_gen(str), + quality_modes=[QualityMode.Benchmark, QualityMode.Consensus], ) assert project.auto_audit_number_of_labels == 3 assert project.auto_audit_percentage == 0 @@ -282,14 +301,18 @@ def test_label_count(client, configured_batch_project_with_label): def test_clone(client, project, rand_gen): # cannot clone unknown project media type - project = client.create_project(name=rand_gen(str), - media_type=MediaType.Image) + project = client.create_project( + name=rand_gen(str), media_type=MediaType.Image + ) cloned_project = project.clone() assert cloned_project.description == project.description assert cloned_project.media_type == project.media_type assert cloned_project.queue_mode == project.queue_mode - assert cloned_project.auto_audit_number_of_labels == project.auto_audit_number_of_labels + assert ( + cloned_project.auto_audit_number_of_labels + == project.auto_audit_number_of_labels + ) assert 
cloned_project.auto_audit_percentage == project.auto_audit_percentage assert cloned_project.get_label_count() == 0 diff --git a/libs/labelbox/tests/integration/test_project_model_config.py b/libs/labelbox/tests/integration/test_project_model_config.py index 7b564b2af..2d783f62b 100644 --- a/libs/labelbox/tests/integration/test_project_model_config.py +++ b/libs/labelbox/tests/integration/test_project_model_config.py @@ -2,52 +2,67 @@ from labelbox.exceptions import ResourceNotFoundError -def test_add_single_model_config(live_chat_evaluation_project_with_new_dataset, - model_config): +def test_add_single_model_config( + live_chat_evaluation_project_with_new_dataset, model_config +): configured_project = live_chat_evaluation_project_with_new_dataset project_model_config_id = configured_project.add_model_config( - model_config.uid) + model_config.uid + ) - assert set(config.uid - for config in configured_project.project_model_configs()) == set( - [project_model_config_id]) + assert set( + config.uid for config in configured_project.project_model_configs() + ) == set([project_model_config_id]) assert configured_project.delete_project_model_config( - project_model_config_id) + project_model_config_id + ) -def test_add_multiple_model_config(client, rand_gen, - live_chat_evaluation_project_with_new_dataset, - model_config, valid_model_id): +def test_add_multiple_model_config( + client, + rand_gen, + live_chat_evaluation_project_with_new_dataset, + model_config, + valid_model_id, +): configured_project = live_chat_evaluation_project_with_new_dataset - second_model_config = client.create_model_config(rand_gen(str), - valid_model_id, - {"param": "value"}) + second_model_config = client.create_model_config( + rand_gen(str), valid_model_id, {"param": "value"} + ) first_project_model_config_id = configured_project.add_model_config( - model_config.uid) + model_config.uid + ) second_project_model_config_id = configured_project.add_model_config( - second_model_config.uid) + second_model_config.uid + ) expected_model_configs = set( - [first_project_model_config_id, second_project_model_config_id]) + [first_project_model_config_id, second_project_model_config_id] + ) - assert set( - config.uid for config in configured_project.project_model_configs() - ) == expected_model_configs + assert ( + set(config.uid for config in configured_project.project_model_configs()) + == expected_model_configs + ) for project_model_config_id in expected_model_configs: assert configured_project.delete_project_model_config( - project_model_config_id) + project_model_config_id + ) -def test_delete_project_model_config(live_chat_evaluation_project_with_new_dataset, - model_config): +def test_delete_project_model_config( + live_chat_evaluation_project_with_new_dataset, model_config +): configured_project = live_chat_evaluation_project_with_new_dataset assert configured_project.delete_project_model_config( - configured_project.add_model_config(model_config.uid)) + configured_project.add_model_config(model_config.uid) + ) assert not len(configured_project.project_model_configs()) def test_delete_nonexistant_project_model_config(configured_project): with pytest.raises(ResourceNotFoundError): configured_project.delete_project_model_config( - "nonexistant_project_model_config") + "nonexistant_project_model_config" + ) diff --git a/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py b/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py index d48514024..1c3e68c9a 100644 --- 
a/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py +++ b/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py @@ -4,24 +4,23 @@ def test_live_chat_evaluation_project( - live_chat_evaluation_project_with_new_dataset, model_config): - + live_chat_evaluation_project_with_new_dataset, model_config +): project = live_chat_evaluation_project_with_new_dataset project.set_project_model_setup_complete() assert bool(project.model_setup_complete) is True with pytest.raises( - expected_exception=LabelboxError, - match= - "Cannot create model config for project because model setup is complete" + expected_exception=LabelboxError, + match="Cannot create model config for project because model setup is complete", ): project.add_model_config(model_config.uid) def test_live_chat_evaluation_project_delete_cofig( - live_chat_evaluation_project_with_new_dataset, model_config): - + live_chat_evaluation_project_with_new_dataset, model_config +): project = live_chat_evaluation_project_with_new_dataset project_model_config_id = project.add_model_config(model_config.uid) assert project_model_config_id @@ -37,30 +36,27 @@ def test_live_chat_evaluation_project_delete_cofig( assert bool(project.model_setup_complete) is True with pytest.raises( - expected_exception=LabelboxError, - match= - "Cannot create model config for project because model setup is complete" + expected_exception=LabelboxError, + match="Cannot create model config for project because model setup is complete", ): project_model_config.delete() -def test_offline_chat_evaluation_project(offline_chat_evaluation_project, - model_config): - +def test_offline_chat_evaluation_project( + offline_chat_evaluation_project, model_config +): project = offline_chat_evaluation_project with pytest.raises( - expected_exception=OperationNotAllowedException, - match= - "Only live model chat evaluation projects can complete model setup" + expected_exception=OperationNotAllowedException, + match="Only live model chat evaluation projects can complete model setup", ): project.set_project_model_setup_complete() def test_any_other_project(project, model_config): with pytest.raises( - expected_exception=OperationNotAllowedException, - match= - "Only live model chat evaluation projects can complete model setup" + expected_exception=OperationNotAllowedException, + match="Only live model chat evaluation projects can complete model setup", ): project.set_project_model_setup_complete() diff --git a/libs/labelbox/tests/integration/test_project_setup.py b/libs/labelbox/tests/integration/test_project_setup.py index 8404b0e50..faadea228 100644 --- a/libs/labelbox/tests/integration/test_project_setup.py +++ b/libs/labelbox/tests/integration/test_project_setup.py @@ -9,16 +9,17 @@ def simple_ontology(): - classifications = [{ - "name": "test_ontology", - "instructions": "Which class is this?", - "type": "radio", - "options": [{ - "value": c, - "label": c - } for c in ["one", "two", "three"]], - "required": True, - }] + classifications = [ + { + "name": "test_ontology", + "instructions": "Which class is this?", + "type": "radio", + "options": [ + {"value": c, "label": c} for c in ["one", "two", "three"] + ], + "required": True, + } + ] return {"tools": [], "classifications": classifications} @@ -26,7 +27,8 @@ def simple_ontology(): def test_project_setup(project) -> None: client = project.client labeling_frontends = list( - client.get_labeling_frontends(where=LabelingFrontend.name == 'Editor')) + 
client.get_labeling_frontends(where=LabelingFrontend.name == "Editor") + ) assert len(labeling_frontends) labeling_frontend = labeling_frontends[0] @@ -64,12 +66,14 @@ def test_project_editor_setup(client, project, rand_gen): assert project.ontology().name == ontology_name # Make sure that setup only creates one ontology time.sleep(3) # Search takes a second - assert [ontology.name for ontology in client.get_ontologies(ontology_name) - ] == [ontology_name] + assert [ + ontology.name for ontology in client.get_ontologies(ontology_name) + ] == [ontology_name] def test_project_connect_ontology_cant_call_multiple_times( - client, project, rand_gen): + client, project, rand_gen +): ontology_name = f"test_project_editor_setup_ontology_name-{rand_gen(str)}" ontology = client.create_ontology(ontology_name, simple_ontology()) project.connect_ontology(ontology) diff --git a/libs/labelbox/tests/integration/test_prompt_response_generation_project.py b/libs/labelbox/tests/integration/test_prompt_response_generation_project.py index 20d42d92c..1373ee470 100644 --- a/libs/labelbox/tests/integration/test_prompt_response_generation_project.py +++ b/libs/labelbox/tests/integration/test_prompt_response_generation_project.py @@ -5,19 +5,25 @@ from labelbox.schema.ontology_kind import OntologyKind from labelbox.exceptions import MalformedQueryException + @pytest.mark.parametrize( "prompt_response_ontology, prompt_response_generation_project_with_new_dataset", [ (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), - (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), + ( + MediaType.LLMPromptResponseCreation, + MediaType.LLMPromptResponseCreation, + ), ], - indirect=True + indirect=True, ) def test_prompt_response_generation_ontology_project( - client, prompt_response_ontology, - prompt_response_generation_project_with_new_dataset, - response_data_row, rand_gen): - + client, + prompt_response_ontology, + prompt_response_generation_project_with_new_dataset, + response_data_row, + rand_gen, +): ontology = prompt_response_ontology assert ontology @@ -35,36 +41,41 @@ def test_prompt_response_generation_ontology_project( assert project.ontology().name == ontology.name with pytest.raises( - ValueError, - match="Cannot create batches for auto data generation projects"): + ValueError, + match="Cannot create batches for auto data generation projects", + ): project.create_batch( rand_gen(str), [response_data_row.uid], # sample of data row objects ) with pytest.raises( - ValueError, - match="Cannot create batches for auto data generation projects"): - with patch('labelbox.schema.project.MAX_SYNC_BATCH_ROW_COUNT', - new=0): # force to async - + ValueError, + match="Cannot create batches for auto data generation projects", + ): + with patch( + "labelbox.schema.project.MAX_SYNC_BATCH_ROW_COUNT", new=0 + ): # force to async project.create_batch( rand_gen(str), - [response_data_row.uid - ], # sample of data row objects + [response_data_row.uid], # sample of data row objects ) + @pytest.mark.parametrize( "prompt_response_ontology, prompt_response_generation_project_with_dataset_id", [ (MediaType.LLMPromptCreation, MediaType.LLMPromptCreation), - (MediaType.LLMPromptResponseCreation, MediaType.LLMPromptResponseCreation), + ( + MediaType.LLMPromptResponseCreation, + MediaType.LLMPromptResponseCreation, + ), ], - indirect=True + indirect=True, ) def test_prompt_response_generation_ontology_project_with_existing_dataset( - prompt_response_ontology, - 
prompt_response_generation_project_with_dataset_id): + prompt_response_ontology, prompt_response_generation_project_with_dataset_id +): ontology = prompt_response_ontology project = prompt_response_generation_project_with_dataset_id @@ -77,48 +88,55 @@ def test_prompt_response_generation_ontology_project_with_existing_dataset( @pytest.fixture def classification_json(): - classifications = [{ - 'featureSchemaId': None, - 'kind': 'Prompt', - 'minCharacters': 2, - 'maxCharacters': 10, - 'name': 'prompt text', - 'instructions': 'prompt text', - 'required': True, - 'schemaNodeId': None, - "scope": "global", - 'type': 'prompt', - 'options': [] - }, { - 'featureSchemaId': None, - 'kind': 'ResponseCheckboxQuestion', - 'name': 'response checklist', - 'instructions': 'response checklist', - 'options': [{'featureSchemaId': None, - 'kind': 'ResponseCheckboxOption', - 'label': 'response checklist option', - 'schemaNodeId': None, - 'position': 0, - 'value': 'option_1'}], - 'required': True, - 'schemaNodeId': None, - "scope": "global", - 'type': 'response-checklist' - }, { - 'featureSchemaId': None, - 'kind': 'ResponseText', - 'maxCharacters': 10, - 'minCharacters': 1, - 'name': 'response text', - 'instructions': 'response text', - 'required': True, - 'schemaNodeId': None, - "scope": "global", - 'type': 'response-text', - 'options': [] - } + classifications = [ + { + "featureSchemaId": None, + "kind": "Prompt", + "minCharacters": 2, + "maxCharacters": 10, + "name": "prompt text", + "instructions": "prompt text", + "required": True, + "schemaNodeId": None, + "scope": "global", + "type": "prompt", + "options": [], + }, + { + "featureSchemaId": None, + "kind": "ResponseCheckboxQuestion", + "name": "response checklist", + "instructions": "response checklist", + "options": [ + { + "featureSchemaId": None, + "kind": "ResponseCheckboxOption", + "label": "response checklist option", + "schemaNodeId": None, + "position": 0, + "value": "option_1", + } + ], + "required": True, + "schemaNodeId": None, + "scope": "global", + "type": "response-checklist", + }, + { + "featureSchemaId": None, + "kind": "ResponseText", + "maxCharacters": 10, + "minCharacters": 1, + "name": "response text", + "instructions": "response text", + "required": True, + "schemaNodeId": None, + "scope": "global", + "type": "response-text", + "options": [], + }, ] - + return classifications @@ -139,7 +157,7 @@ def ontology_from_feature_ids(client, features_from_json): ontology = client.create_ontology_from_feature_schemas( name="test-prompt_response_creation{rand_gen(str)}", feature_schema_ids=feature_ids, - media_type=MediaType.LLMPromptResponseCreation + media_type=MediaType.LLMPromptResponseCreation, ) yield ontology @@ -147,18 +165,22 @@ def ontology_from_feature_ids(client, features_from_json): client.delete_unused_ontology(ontology.uid) -def test_ontology_create_feature_schema(ontology_from_feature_ids, - features_from_json, classification_json): +def test_ontology_create_feature_schema( + ontology_from_feature_ids, features_from_json, classification_json +): created_ontology = ontology_from_feature_ids feature_schema_ids = {f.uid for f in features_from_json} - classifications_normalized = created_ontology.normalized['classifications'] + classifications_normalized = created_ontology.normalized["classifications"] classifications = classification_json for classification in classifications: generated_tool = next( - c for c in classifications_normalized if c['name'] == classification['name']) - assert generated_tool['schemaNodeId'] is not None - 
assert generated_tool['featureSchemaId'] in feature_schema_ids - assert generated_tool['type'] == classification['type'] - assert generated_tool['name'] == classification['name'] - assert generated_tool['required'] == classification['required'] + c + for c in classifications_normalized + if c["name"] == classification["name"] + ) + assert generated_tool["schemaNodeId"] is not None + assert generated_tool["featureSchemaId"] in feature_schema_ids + assert generated_tool["type"] == classification["type"] + assert generated_tool["name"] == classification["name"] + assert generated_tool["required"] == classification["required"] diff --git a/libs/labelbox/tests/integration/test_response_creation_project.py b/libs/labelbox/tests/integration/test_response_creation_project.py index 76ba12d54..d7f9a1e46 100644 --- a/libs/labelbox/tests/integration/test_response_creation_project.py +++ b/libs/labelbox/tests/integration/test_response_creation_project.py @@ -3,11 +3,17 @@ from labelbox.schema.ontology_kind import OntologyKind -@pytest.mark.parametrize("prompt_response_ontology", [OntologyKind.ResponseCreation], indirect=True) -def test_create_response_creation_project(client, rand_gen, - response_creation_project, - prompt_response_ontology, - response_data_row): + +@pytest.mark.parametrize( + "prompt_response_ontology", [OntologyKind.ResponseCreation], indirect=True +) +def test_create_response_creation_project( + client, + rand_gen, + response_creation_project, + prompt_response_ontology, + response_data_row, +): project: Project = response_creation_project assert project @@ -21,4 +27,4 @@ def test_create_response_creation_project(client, rand_gen, rand_gen(str), [response_data_row.uid], # sample of data row objects ) - assert batch \ No newline at end of file + assert batch diff --git a/libs/labelbox/tests/integration/test_send_to_annotate.py b/libs/labelbox/tests/integration/test_send_to_annotate.py index fd358324f..3ba4d13a5 100644 --- a/libs/labelbox/tests/integration/test_send_to_annotate.py +++ b/libs/labelbox/tests/integration/test_send_to_annotate.py @@ -1,11 +1,16 @@ from labelbox import UniqueIds, Project, Ontology, Client -from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy +from labelbox.schema.conflict_resolution_strategy import ( + ConflictResolutionStrategy, +) from typing import List def test_send_to_annotate_include_annotations( - client: Client, configured_batch_project_with_label: Project, - project_pack: List[Project], ontology: Ontology): + client: Client, + configured_batch_project_with_label: Project, + project_pack: List[Project], + ontology: Ontology, +): [source_project, _, data_row, _] = configured_batch_project_with_label destination_project: Project = project_pack[0] @@ -14,18 +19,22 @@ def test_send_to_annotate_include_annotations( # build an ontology mapping using the top level tools src_feature_schema_ids = list( - tool.feature_schema_id for tool in src_ontology.tools()) + tool.feature_schema_id for tool in src_ontology.tools() + ) dest_ontology = destination_project.ontology() dest_feature_schema_ids = list( - tool.feature_schema_id for tool in dest_ontology.tools()) + tool.feature_schema_id for tool in dest_ontology.tools() + ) # create a dictionary of feature schema id to itself - ontology_mapping = dict(zip(src_feature_schema_ids, - dest_feature_schema_ids)) + ontology_mapping = dict( + zip(src_feature_schema_ids, dest_feature_schema_ids) + ) try: queues = destination_project.task_queues() initial_review_task = next( - q for q in 
queues if q.name == "Initial review task") + q for q in queues if q.name == "Initial review task" + ) # Send the data row to the new project task = client.send_to_annotate_from_catalog( @@ -34,13 +43,11 @@ def test_send_to_annotate_include_annotations( batch_name="test-batch", data_rows=UniqueIds([data_row.uid]), params={ - "source_project_id": - source_project.uid, - "annotations_ontology_mapping": - ontology_mapping, - "override_existing_annotations_rule": - ConflictResolutionStrategy.OverrideWithAnnotations - }) + "source_project_id": source_project.uid, + "annotations_ontology_mapping": ontology_mapping, + "override_existing_annotations_rule": ConflictResolutionStrategy.OverrideWithAnnotations, + }, + ) task.wait_till_done() @@ -57,7 +64,7 @@ def test_send_to_annotate_include_annotations( assert destination_data_rows[0] == data_row.uid # Verify annotations were copied into the destination project - destination_project_labels = (list(destination_project.labels())) + destination_project_labels = list(destination_project.labels()) assert len(destination_project_labels) == 1 finally: destination_project.delete() diff --git a/libs/labelbox/tests/integration/test_task.py b/libs/labelbox/tests/integration/test_task.py index b0eac2fa1..da89e4bb0 100644 --- a/libs/labelbox/tests/integration/test_task.py +++ b/libs/labelbox/tests/integration/test_task.py @@ -9,42 +9,50 @@ def test_task_errors(dataset, image_url, snapshot): client = dataset.client - task = dataset.create_data_rows([ - { - DataRow.row_data: - image_url, - DataRow.metadata_fields: [ - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, - value='some msg'), - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, - value='some msg 2') - ] - }, - ]) + task = dataset.create_data_rows( + [ + { + DataRow.row_data: image_url, + DataRow.metadata_fields: [ + DataRowMetadataField( + schema_id=TEXT_SCHEMA_ID, value="some msg" + ), + DataRowMetadataField( + schema_id=TEXT_SCHEMA_ID, value="some msg 2" + ), + ], + }, + ] + ) assert task in client.get_user().created_tasks() task.wait_till_done() assert len(task.failed_data_rows) == 1 - assert "A schemaId can only be specified once per DataRow : [cko8s9r5v0001h2dk9elqdidh]" in task.failed_data_rows[ - 0]['message'] - assert len(task.failed_data_rows[0]['failedDataRows'][0]['metadata']) == 2 + assert ( + "A schemaId can only be specified once per DataRow : [cko8s9r5v0001h2dk9elqdidh]" + in task.failed_data_rows[0]["message"] + ) + assert len(task.failed_data_rows[0]["failedDataRows"][0]["metadata"]) == 2 dt = client.get_task_by_id(task.uid) assert dt.status == "COMPLETE" assert len(dt.errors) == 1 - assert dt.errors[0]['message'].startswith( - "A schemaId can only be specified once per DataRow") + assert dt.errors[0]["message"].startswith( + "A schemaId can only be specified once per DataRow" + ) assert dt.result is None def test_task_success_json(dataset, image_url, snapshot): client = dataset.client - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url, - }, - ]) + task = dataset.create_data_rows( + [ + { + DataRow.row_data: image_url, + }, + ] + ) assert task in client.get_user().created_tasks() task.wait_till_done() assert task.status == "COMPLETE" @@ -54,14 +62,16 @@ def test_task_success_json(dataset, image_url, snapshot): assert task.result_url is not None assert isinstance(task.result_url, str) task_result = task.result[0] - assert 'id' in task_result and isinstance(task_result['id'], str) - assert 'row_data' in task_result and isinstance(task_result['row_data'], - str) + assert "id" 
in task_result and isinstance(task_result["id"], str) + assert "row_data" in task_result and isinstance( + task_result["row_data"], str + ) snapshot.snapshot_dir = INTEGRATION_SNAPSHOT_DIRECTORY - task_result['id'] = 'DUMMY_ID' - task_result['row_data'] = 'https://dummy.url' - snapshot.assert_match(json.dumps(task_result), - 'test_task.test_task_success_json.json') + task_result["id"] = "DUMMY_ID" + task_result["row_data"] = "https://dummy.url" + snapshot.assert_match( + json.dumps(task_result), "test_task.test_task_success_json.json" + ) assert len(task.result) dt = client.get_task_by_id(task.uid) diff --git a/libs/labelbox/tests/integration/test_task_queue.py b/libs/labelbox/tests/integration/test_task_queue.py index 2a6ca45d8..835f67219 100644 --- a/libs/labelbox/tests/integration/test_task_queue.py +++ b/libs/labelbox/tests/integration/test_task_queue.py @@ -7,7 +7,8 @@ def test_get_task_queue(project: Project): task_queues = project.task_queues() assert len(task_queues) == 3 review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") + tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE" + ) assert review_queue @@ -23,6 +24,7 @@ def test_get_overview_no_details(project: Project): assert isinstance(po.labeled, int) assert isinstance(po.total_data_rows, int) + def test_get_overview_with_details(project: Project): po = project.get_overview(details=True) @@ -37,20 +39,23 @@ def test_get_overview_with_details(project: Project): assert isinstance(po.labeled, int) assert isinstance(po.total_data_rows, int) + def _validate_moved(project, queue_name, data_row_count): timeout_seconds = 30 sleep_time = 2 while True: task_queues = project.task_queues() review_queue = next( - tq for tq in task_queues if tq.queue_type == queue_name) + tq for tq in task_queues if tq.queue_type == queue_name + ) if review_queue.data_row_count == data_row_count: break if timeout_seconds <= 0: raise AssertionError( - "Timed out expecting data_row_count of 1 in the review queue") + "Timed out expecting data_row_count of 1 in the review queue" + ) timeout_seconds -= sleep_time time.sleep(sleep_time) @@ -61,18 +66,23 @@ def test_move_to_task(configured_batch_project_with_label): task_queues = project.task_queues() review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") + tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE" + ) project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid) _validate_moved(project, "MANUAL_REVIEW_QUEUE", 1) review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REWORK_QUEUE") - project.move_data_rows_to_task_queue(GlobalKeys([data_row.global_key]), - review_queue.uid) + tq for tq in task_queues if tq.queue_type == "MANUAL_REWORK_QUEUE" + ) + project.move_data_rows_to_task_queue( + GlobalKeys([data_row.global_key]), review_queue.uid + ) _validate_moved(project, "MANUAL_REWORK_QUEUE", 1) review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") - project.move_data_rows_to_task_queue(UniqueIds([data_row.uid]), - review_queue.uid) + tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE" + ) + project.move_data_rows_to_task_queue( + UniqueIds([data_row.uid]), review_queue.uid + ) _validate_moved(project, "MANUAL_REVIEW_QUEUE", 1) diff --git a/libs/labelbox/tests/integration/test_user_and_org.py b/libs/labelbox/tests/integration/test_user_and_org.py index ca158527c..7bb72051f 100644 --- 
a/libs/labelbox/tests/integration/test_user_and_org.py +++ b/libs/labelbox/tests/integration/test_user_and_org.py @@ -20,4 +20,4 @@ def test_user_and_org_projects(client, project): org_project = org.projects(where=Project.uid == project.uid) assert user_project - assert org_project \ No newline at end of file + assert org_project diff --git a/libs/labelbox/tests/integration/test_user_management.py b/libs/labelbox/tests/integration/test_user_management.py index ca4328f51..cfdf3c566 100644 --- a/libs/labelbox/tests/integration/test_user_management.py +++ b/libs/labelbox/tests/integration/test_user_management.py @@ -8,14 +8,17 @@ @pytest.fixture def org_invite(client, organization, environ, queries): - role = client.get_roles()['LABELER'] + role = client.get_roles()["LABELER"] - dummy_email = "none+{}@labelbox.com".format("".join( - faker.random_letters(26))) + dummy_email = "none+{}@labelbox.com".format( + "".join(faker.random_letters(26)) + ) invite_limit = organization.invite_limit() if environ.value == "prod": - assert invite_limit.remaining > 0, "No invites available for the account associated with this key." + assert ( + invite_limit.remaining > 0 + ), "No invites available for the account associated with this key." elif environ.value != "staging": # Cannot run against local return @@ -31,26 +34,29 @@ def org_invite(client, organization, environ, queries): def project_role_1(client, project_pack): project_1, _ = project_pack roles = client.get_roles() - return ProjectRole(project=project_1, role=roles['LABELER']) + return ProjectRole(project=project_1, role=roles["LABELER"]) @pytest.fixture def project_role_2(client, project_pack): _, project_2 = project_pack roles = client.get_roles() - return ProjectRole(project=project_2, role=roles['REVIEWER']) + return ProjectRole(project=project_2, role=roles["REVIEWER"]) @pytest.fixture -def create_project_invite(client, organization, project_pack, queries, - project_role_1, project_role_2): +def create_project_invite( + client, organization, project_pack, queries, project_role_1, project_role_2 +): roles = client.get_roles() - dummy_email = "none+{}@labelbox.com".format("".join( - faker.random_letters(26))) + dummy_email = "none+{}@labelbox.com".format( + "".join(faker.random_letters(26)) + ) invite = organization.invite_user( dummy_email, - roles['NONE'], - project_roles=[project_role_1, project_role_2]) + roles["NONE"], + project_roles=[project_role_1, project_role_2], + ) yield invite @@ -59,10 +65,9 @@ def create_project_invite(client, organization, project_pack, queries, def test_org_invite(client, organization, environ, queries, org_invite): invite, invite_limit = org_invite - role = client.get_roles()['LABELER'] + role = client.get_roles()["LABELER"] if environ.value == "prod": - invite_limit_after = organization.invite_limit() # One user added assert invite_limit.remaining - invite_limit_after.remaining == 1 @@ -75,7 +80,8 @@ def test_org_invite(client, organization, environ, queries, org_invite): if outstanding_invite.uid == invite.uid: in_list = True org_role = outstanding_invite.organization_role_name.lower() - assert org_role == role.name.lower( + assert ( + org_role == role.name.lower() ), "Role should be labeler. 
Found {org_role} " assert in_list, "Invite not found" @@ -85,44 +91,67 @@ def test_cancel_invite( organization, queries, ): - role = client.get_roles()['LABELER'] - dummy_email = "none+{}@labelbox.com".format("".join( - faker.random_letters(26))) + role = client.get_roles()["LABELER"] + dummy_email = "none+{}@labelbox.com".format( + "".join(faker.random_letters(26)) + ) invite = organization.invite_user(dummy_email, role) queries.cancel_invite(client, invite.uid) outstanding_invites = [i.uid for i in queries.get_invites(client)] assert invite.uid not in outstanding_invites -def test_project_invite(client, organization, project_pack, queries, - create_project_invite, project_role_1, project_role_2): +def test_project_invite( + client, + organization, + project_pack, + queries, + create_project_invite, + project_role_1, + project_role_2, +): create_project_invite project_1, _ = project_pack roles = client.get_roles() project_invite = next(queries.get_project_invites(client, project_1.uid)) - assert set([(proj_invite.project.uid, proj_invite.role.uid) - for proj_invite in project_invite.project_roles - ]) == set([(proj_role.project.uid, proj_role.role.uid) - for proj_role in [project_role_1, project_role_2]]) - - assert set([(proj_invite.project.uid, proj_invite.role.uid) - for proj_invite in project_invite.project_roles - ]) == set([(proj_role.project.uid, proj_role.role.uid) - for proj_role in [project_role_1, project_role_2]]) + assert set( + [ + (proj_invite.project.uid, proj_invite.role.uid) + for proj_invite in project_invite.project_roles + ] + ) == set( + [ + (proj_role.project.uid, proj_role.role.uid) + for proj_role in [project_role_1, project_role_2] + ] + ) + + assert set( + [ + (proj_invite.project.uid, proj_invite.role.uid) + for proj_invite in project_invite.project_roles + ] + ) == set( + [ + (proj_role.project.uid, proj_role.role.uid) + for proj_role in [project_role_1, project_role_2] + ] + ) project_members = project_1.members() project_member = [ - member for member in project_members + member + for member in project_members if member.user().uid == client.get_user().uid ] assert len(project_member) == 1 project_member = project_member[0] - assert project_member.access_from == 'ORGANIZATION' - assert project_member.role().name.upper() == roles['ADMIN'].name.upper() + assert project_member.access_from == "ORGANIZATION" + assert project_member.role().name.upper() == roles["ADMIN"].name.upper() @pytest.mark.skip( @@ -131,8 +160,7 @@ def test_project_invite(client, organization, project_pack, queries, def test_member_management(client, organization, project, project_based_user): roles = client.get_roles() assert not len(list(project_based_user.projects())) - for role in [roles['LABELER'], roles['REVIEWER']]: - + for role in [roles["LABELER"], roles["REVIEWER"]]: project_based_user.upsert_project_role(project, role=role) members = project.members() is_member = False @@ -148,11 +176,14 @@ def test_member_management(client, organization, project, project_based_user): for member in project.members(): assert member.user().uid != project_based_user.uid - assert project_based_user.org_role().name.upper( - ) == roles['NONE'].name.upper() + assert ( + project_based_user.org_role().name.upper() == roles["NONE"].name.upper() + ) for role in [ - roles['TEAM_MANAGER'], roles['ADMIN'], roles['LABELER'], - roles['REVIEWER'] + roles["TEAM_MANAGER"], + roles["ADMIN"], + roles["LABELER"], + roles["REVIEWER"], ]: project_based_user.update_org_role(role) project_based_user.org_role().name.upper() 
== role.name.upper() diff --git a/libs/labelbox/tests/integration/test_webhook.py b/libs/labelbox/tests/integration/test_webhook.py index 25c8c667a..b93255c4e 100644 --- a/libs/labelbox/tests/integration/test_webhook.py +++ b/libs/labelbox/tests/integration/test_webhook.py @@ -25,19 +25,25 @@ def test_webhook_create_update(project, rand_gen): with pytest.raises(ValueError) as exc_info: webhook.update(status="invalid..") valid_webhook_statuses = {item.value for item in Webhook.Status} - assert str(exc_info.value) == \ - f"Value `invalid..` does not exist in supported values. Expected one of {valid_webhook_statuses}" + assert ( + str(exc_info.value) + == f"Value `invalid..` does not exist in supported values. Expected one of {valid_webhook_statuses}" + ) with pytest.raises(ValueError) as exc_info: webhook.update(topics=["invalid.."]) valid_webhook_topics = {item.value for item in Webhook.Topic} - assert str(exc_info.value) == \ - f"Value `invalid..` does not exist in supported values. Expected one of {valid_webhook_topics}" + assert ( + str(exc_info.value) + == f"Value `invalid..` does not exist in supported values. Expected one of {valid_webhook_topics}" + ) with pytest.raises(TypeError) as exc_info: webhook.update(topics="invalid..") - assert str(exc_info.value) == \ - "Topics must be List[Webhook.Topic]. Found `invalid..`" + assert ( + str(exc_info.value) + == "Topics must be List[Webhook.Topic]. Found `invalid..`" + ) webhook.delete() @@ -50,8 +56,7 @@ def test_webhook_create_with_no_secret(project, rand_gen): with pytest.raises(ValueError) as exc_info: Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "Secret must be a non-empty string." + assert str(exc_info.value) == "Secret must be a non-empty string." def test_webhook_create_with_no_topics(project, rand_gen): @@ -62,8 +67,7 @@ def test_webhook_create_with_no_topics(project, rand_gen): with pytest.raises(ValueError) as exc_info: Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "Topics must be a non-empty list." + assert str(exc_info.value) == "Topics must be a non-empty list." def test_webhook_create_with_no_url(project, rand_gen): @@ -74,5 +78,4 @@ def test_webhook_create_with_no_url(project, rand_gen): with pytest.raises(ValueError) as exc_info: Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "URL must be a non-empty string." + assert str(exc_info.value) == "URL must be a non-empty string." 
diff --git a/libs/labelbox/tests/unit/conftest.py b/libs/labelbox/tests/unit/conftest.py index 0e8de8185..603fa9908 100644 --- a/libs/labelbox/tests/unit/conftest.py +++ b/libs/labelbox/tests/unit/conftest.py @@ -6,40 +6,25 @@ def ndjson_content(): line = """{"uuid": "9fd9a92e-2560-4e77-81d4-b2e955800092", "schemaId": "ckaeasyfk004y0y7wyye5epgu", "dataRow": {"id": "ck7kftpan8ir008910yf07r9c"}, "bbox": {"top": 48, "left": 58, "height": 865, "width": 1512}} {"uuid": "29b878f3-c2b4-4dbf-9f22-a795f0720125", "schemaId": "ckapgvrl7007q0y7ujkjkaaxt", "dataRow": {"id": "ck7kftpan8ir008910yf07r9c"}, "polygon": [{"x": 147.692, "y": 118.154}, {"x": 142.769, "y": 404.923}, {"x": 57.846, "y": 318.769}, {"x": 28.308, "y": 169.846}]}""" - expected_objects = [{ - 'uuid': '9fd9a92e-2560-4e77-81d4-b2e955800092', - 'schemaId': 'ckaeasyfk004y0y7wyye5epgu', - 'dataRow': { - 'id': 'ck7kftpan8ir008910yf07r9c' + expected_objects = [ + { + "uuid": "9fd9a92e-2560-4e77-81d4-b2e955800092", + "schemaId": "ckaeasyfk004y0y7wyye5epgu", + "dataRow": {"id": "ck7kftpan8ir008910yf07r9c"}, + "bbox": {"top": 48, "left": 58, "height": 865, "width": 1512}, }, - 'bbox': { - 'top': 48, - 'left': 58, - 'height': 865, - 'width': 1512 - } - }, { - 'uuid': - '29b878f3-c2b4-4dbf-9f22-a795f0720125', - 'schemaId': - 'ckapgvrl7007q0y7ujkjkaaxt', - 'dataRow': { - 'id': 'ck7kftpan8ir008910yf07r9c' + { + "uuid": "29b878f3-c2b4-4dbf-9f22-a795f0720125", + "schemaId": "ckapgvrl7007q0y7ujkjkaaxt", + "dataRow": {"id": "ck7kftpan8ir008910yf07r9c"}, + "polygon": [ + {"x": 147.692, "y": 118.154}, + {"x": 142.769, "y": 404.923}, + {"x": 57.846, "y": 318.769}, + {"x": 28.308, "y": 169.846}, + ], }, - 'polygon': [{ - 'x': 147.692, - 'y': 118.154 - }, { - 'x': 142.769, - 'y': 404.923 - }, { - 'x': 57.846, - 'y': 318.769 - }, { - 'x': 28.308, - 'y': 169.846 - }] - }] + ] return line, expected_objects @@ -47,65 +32,55 @@ def ndjson_content(): @pytest.fixture def ndjson_content_with_nonascii_and_line_breaks(): line = '{"id": "2489651127", "type": "PushEvent", "actor": {"id": 1459915, "login": "xtuaok", "gravatar_id": "", "url": "https://api.github.com/users/xtuaok", "avatar_url": "https://avatars.githubusercontent.com/u/1459915?"}, "repo": {"id": 6719841, "name": "xtuaok/twitter_track_following", "url": "https://api.github.com/repos/xtuaok/twitter_track_following"}, "payload": {"push_id": 536864008, "size": 1, "distinct_size": 1, "ref": "refs/heads/xtuaok", "head": "afb8afe306c7893d93d383a06e4d9df53b41bf47", "before": "4671b4868f1a060f2ed64d8268cd22d514a84e63", "commits": [{"sha": "afb8afe306c7893d93d383a06e4d9df53b41bf47", "author": {"email": "47cb89439b2d6961b59dff4298e837f67aa77389@gmail.com", "name": "Tomonori Tamagawa"}, "message": "Update ID 949438177,, - screen_name: chomado, - name: ちょまど@初詣おみくじ凶, - description: ( *゚▽゚* っ)З腐女子!絵描き!| H26新卒文系SE (入社して4ヶ月目の8月にSIer(適応障害になった)を辞職し開発者に転職) | H26秋応用情報合格!| 自作bot (in PHP) chomado_bot | プログラミングガチ初心者, - location:", "distinct": true, "url": "https://api.github.com/repos/xtuaok/twitter_track_following/commits/afb8afe306c7893d93d383a06e4d9df53b41bf47"}]}, "public": true, "created_at": "2015-01-01T15:00:10Z"}' - expected_objects = [{ - 'id': '2489651127', - 'type': 'PushEvent', - 'actor': { - 'id': 1459915, - 'login': 'xtuaok', - 'gravatar_id': '', - 'url': 'https://api.github.com/users/xtuaok', - 'avatar_url': 'https://avatars.githubusercontent.com/u/1459915?' 
- }, - 'repo': { - 'id': 6719841, - 'name': 'xtuaok/twitter_track_following', - 'url': 'https://api.github.com/repos/xtuaok/twitter_track_following' - }, - 'payload': { - 'push_id': - 536864008, - 'size': - 1, - 'distinct_size': - 1, - 'ref': - 'refs/heads/xtuaok', - 'head': - 'afb8afe306c7893d93d383a06e4d9df53b41bf47', - 'before': - '4671b4868f1a060f2ed64d8268cd22d514a84e63', - 'commits': [{ - 'sha': - 'afb8afe306c7893d93d383a06e4d9df53b41bf47', - 'author': { - 'email': - '47cb89439b2d6961b59dff4298e837f67aa77389@gmail.com', - 'name': - 'Tomonori Tamagawa' - }, - 'message': - 'Update ID 949438177,, - screen_name: chomado, - name: ちょまど@初詣おみくじ凶, - description: ( *゚▽゚* っ)З腐女子!絵描き!| H26新卒文系SE (入社して4ヶ月目の8月にSIer(適応障害になった)を辞職し開発者に転職) | H26秋応用情報合格!| 自作bot (in PHP) chomado_bot | プログラミングガチ初心者, - location:', - 'distinct': - True, - 'url': - 'https://api.github.com/repos/xtuaok/twitter_track_following/commits/afb8afe306c7893d93d383a06e4d9df53b41bf47' - }] - }, - 'public': True, - 'created_at': '2015-01-01T15:00:10Z' - }] + expected_objects = [ + { + "id": "2489651127", + "type": "PushEvent", + "actor": { + "id": 1459915, + "login": "xtuaok", + "gravatar_id": "", + "url": "https://api.github.com/users/xtuaok", + "avatar_url": "https://avatars.githubusercontent.com/u/1459915?", + }, + "repo": { + "id": 6719841, + "name": "xtuaok/twitter_track_following", + "url": "https://api.github.com/repos/xtuaok/twitter_track_following", + }, + "payload": { + "push_id": 536864008, + "size": 1, + "distinct_size": 1, + "ref": "refs/heads/xtuaok", + "head": "afb8afe306c7893d93d383a06e4d9df53b41bf47", + "before": "4671b4868f1a060f2ed64d8268cd22d514a84e63", + "commits": [ + { + "sha": "afb8afe306c7893d93d383a06e4d9df53b41bf47", + "author": { + "email": "47cb89439b2d6961b59dff4298e837f67aa77389@gmail.com", + "name": "Tomonori Tamagawa", + }, + "message": "Update ID 949438177,, - screen_name: chomado, - name: ちょまど@初詣おみくじ凶, - description: ( *゚▽゚* っ)З腐女子!絵描き!| H26新卒文系SE (入社して4ヶ月目の8月にSIer(適応障害になった)を辞職し開発者に転職) | H26秋応用情報合格!| 自作bot (in PHP) chomado_bot | プログラミングガチ初心者, - location:", + "distinct": True, + "url": "https://api.github.com/repos/xtuaok/twitter_track_following/commits/afb8afe306c7893d93d383a06e4d9df53b41bf47", + } + ], + }, + "public": True, + "created_at": "2015-01-01T15:00:10Z", + } + ] return line, expected_objects @pytest.fixture def generate_random_ndjson(rand_gen): - def _generate_random_ndjson(lines: int = 10): return [ - json.dumps({"data_row": { - "id": rand_gen(str) - }}) for _ in range(lines) + json.dumps({"data_row": {"id": rand_gen(str)}}) + for _ in range(lines) ] return _generate_random_ndjson @@ -113,9 +88,7 @@ def _generate_random_ndjson(lines: int = 10): @pytest.fixture def mock_response(): - class MockResponse: - def __init__(self, text: str, exception: Exception = None) -> None: self._text = text self._exception = exception diff --git a/libs/labelbox/tests/unit/export_task/test_export_task.py b/libs/labelbox/tests/unit/export_task/test_export_task.py index 50f08191b..ac84a875b 100644 --- a/libs/labelbox/tests/unit/export_task/test_export_task.py +++ b/libs/labelbox/tests/unit/export_task/test_export_task.py @@ -6,9 +6,8 @@ class TestExportTask: - def test_export_task(self): - with patch('requests.get') as mock_requests_get: + with patch("requests.get") as mock_requests_get: mock_task = MagicMock() mock_task.client.execute.side_effect = [ { @@ -16,15 +15,9 @@ def test_export_task(self): "exportMetadataHeader": { "total_size": 1, "total_lines": 1, - "lines": { - "start": 0, - "end": 1 - }, - 
"offsets": { - "start": 0, - "end": 0 - }, - "file": "file" + "lines": {"start": 0, "end": 1}, + "offsets": {"start": 0, "end": 0}, + "file": "file", } } }, @@ -33,15 +26,9 @@ def test_export_task(self): "exportFileFromOffset": { "total_size": 1, "total_lines": 1, - "lines": { - "start": 0, - "end": 1 - }, - "offsets": { - "start": 0, - "end": 0 - }, - "file": "file" + "lines": {"start": 0, "end": 1}, + "offsets": {"start": 0, "end": 0}, + "file": "file", } } }, @@ -49,8 +36,7 @@ def test_export_task(self): mock_task.status = "COMPLETE" data = { "data_row": { - "raw_data": - """ + "raw_data": """ {"raw_text":"}{"} {"raw_text":"\\nbad"} """ @@ -76,7 +62,7 @@ def test_get_buffered_stream_failed(self): export_task.get_buffered_stream() def test_get_buffered_stream(self): - with patch('requests.get') as mock_requests_get: + with patch("requests.get") as mock_requests_get: mock_task = MagicMock() mock_task.client.execute.side_effect = [ { @@ -84,15 +70,9 @@ def test_get_buffered_stream(self): "exportMetadataHeader": { "total_size": 1, "total_lines": 1, - "lines": { - "start": 0, - "end": 1 - }, - "offsets": { - "start": 0, - "end": 0 - }, - "file": "file" + "lines": {"start": 0, "end": 1}, + "offsets": {"start": 0, "end": 0}, + "file": "file", } } }, @@ -101,15 +81,9 @@ def test_get_buffered_stream(self): "exportFileFromOffset": { "total_size": 1, "total_lines": 1, - "lines": { - "start": 0, - "end": 1 - }, - "offsets": { - "start": 0, - "end": 0 - }, - "file": "file" + "lines": {"start": 0, "end": 1}, + "offsets": {"start": 0, "end": 0}, + "file": "file", } } }, @@ -117,8 +91,7 @@ def test_get_buffered_stream(self): mock_task.status = "COMPLETE" data = { "data_row": { - "raw_data": - """ + "raw_data": """ {"raw_text":"}{"} {"raw_text":"\\nbad"} """ @@ -128,11 +101,13 @@ def test_get_buffered_stream(self): mock_requests_get.return_value.content = "b" export_task = ExportTask(mock_task, is_export_v2=True) output_data = [] - export_task.get_buffered_stream().start(stream_handler=lambda x: output_data.append(x.json)) + export_task.get_buffered_stream().start( + stream_handler=lambda x: output_data.append(x.json) + ) assert data == output_data[0] def test_export_task_bad_offsets(self): - with patch('requests.get') as mock_requests_get: + with patch("requests.get") as mock_requests_get: mock_task = MagicMock() mock_task.client.execute.side_effect = [ { @@ -140,15 +115,9 @@ def test_export_task_bad_offsets(self): "exportMetadataHeader": { "total_size": 1, "total_lines": 1, - "lines": { - "start": 0, - "end": 1 - }, - "offsets": { - "start": 0, - "end": 0 - }, - "file": "file" + "lines": {"start": 0, "end": 1}, + "offsets": {"start": 0, "end": 0}, + "file": "file", } } }, @@ -157,15 +126,9 @@ def test_export_task_bad_offsets(self): "exportFileFromOffset": { "total_size": 1, "total_lines": 1, - "lines": { - "start": 0, - "end": 1 - }, - "offsets": { - "start": 0, - "end": 0 - }, - "file": "file" + "lines": {"start": 0, "end": 1}, + "offsets": {"start": 0, "end": 0}, + "file": "file", } } }, @@ -173,21 +136,17 @@ def test_export_task_bad_offsets(self): mock_task.status = "COMPLETE" data = { "data_row": { - "id": - "clwb6wvpv3mpx0712aafl9m00", - "external_id": - "43cdad5e-1fcf-450d-ad72-df4460edf973", - "global_key": - "9ab56c5a-5c2f-45ae-8e21-e53eb415cefe", - "row_data": - "{\"type\":\"application/vnd.labelbox.conversational\",\"version\":1,\"messages\":[{\"messageId\":\"message-0\",\"timestampUsec\":1530718491,\"content\":\"The minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$ 
is\",\"user\":{\"userId\":\"prompt\",\"name\":\"prompt\"},\"align\":\"left\",\"canLabel\":true}],\"modelOutputs\":[{\"title\":\"Response 1\",\"content\":\"To find the minimum value of the expression $3 \\\\cos x + 4 \\\\sin x + 8$, we can use the fact that $a\\\\cos x+b\\\\sin x=\\\\sqrt{a^2+b^2}\\\\left(\\\\frac{a}{\\\\sqrt{a^2+b^2}}\\\\cos x+\\\\frac{b}{\\\\sqrt{a^2+b^2}}\\\\sin x\\\\right)$. This allows us to rewrite the expression as:\\n\\n$3\\\\cos x+4\\\\sin x+8=\\\\sqrt{3^2+4^2}\\\\left(\\\\frac{3}{\\\\sqrt{3^2+4^2}}\\\\cos x+\\\\frac{4}{\\\\sqrt{3^2+4^2}}\\\\sin x\\\\right)+8=5\\\\left(\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x\\\\right)+8$\\n\\nNow, let's consider the expression $\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x$. Since $\\\\left(\\\\frac{3}{5}\\\\right)^2+\\\\left(\\\\frac{4}{5}\\\\right)^2=1$, we can write $\\\\frac{3}{5}=\\\\cos\\\\theta$ and $\\\\frac{4}{5}=\\\\sin\\\\theta$ for some angle $\\\\theta$. Then:\\n\\n$\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x=\\\\cos\\\\theta\\\\cos x+\\\\sin\\\\theta\\\\sin x=\\\\cos(x-\\\\theta)$\\n\\nSo, the original expression can be written as:\\n\\n$5\\\\cos(x-\\\\theta)+8$\\n\\nSince the minimum value of $\\\\cos(x-\\\\theta)$ is $-1$, the minimum value of the original expression is:\\n\\n$5(-1)+8=-5+8=3$\\n\\nTherefore, the minimum value of $3\\\\cos x + 4\\\\sin x + 8$ is $\\\\boxed{3}$.\",\"modelConfigName\":\"null\"},{\"title\":\"Response 2\",\"content\":\"A nice math question!\\n\\nTo find the minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$, we can use the fact that $a\\\\cos x + b\\\\sin x = \\\\sqrt{a^2 + b^2} \\\\cos(x - \\\\alpha)$, where $\\\\alpha = \\\\tan^{-1}\\\\left(\\\\frac{b}{a}\\\\right)$.\\n\\nIn this case, $a = 3$ and $b = 4$, so $\\\\alpha = \\\\tan^{-1}\\\\left(\\\\frac{4}{3}\\\\right)$.\\n\\nSo, we have:\\n\\n$$3 \\\\cos x + 4 \\\\sin x + 8 = \\\\sqrt{3^2 + 4^2} \\\\cos(x - \\\\alpha) + 8 = 5 \\\\cos(x - \\\\alpha) + 8$$\\n\\nNow, the minimum value of $\\\\cos(x - \\\\alpha)$ is $-1$, so the minimum value of $5 \\\\cos(x - \\\\alpha) + 8$ is:\\n\\n$$5(-1) + 8 = -5 + 8 = 3$$\\n\\nTherefore, the minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$ is $\\\\boxed{3}$.\\n\\nLet me know if you have any questions or need further clarification!\",\"modelConfigName\":\"null\"}]}" + "id": "clwb6wvpv3mpx0712aafl9m00", + "external_id": "43cdad5e-1fcf-450d-ad72-df4460edf973", + "global_key": "9ab56c5a-5c2f-45ae-8e21-e53eb415cefe", + "row_data": '{"type":"application/vnd.labelbox.conversational","version":1,"messages":[{"messageId":"message-0","timestampUsec":1530718491,"content":"The minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$ is","user":{"userId":"prompt","name":"prompt"},"align":"left","canLabel":true}],"modelOutputs":[{"title":"Response 1","content":"To find the minimum value of the expression $3 \\\\cos x + 4 \\\\sin x + 8$, we can use the fact that $a\\\\cos x+b\\\\sin x=\\\\sqrt{a^2+b^2}\\\\left(\\\\frac{a}{\\\\sqrt{a^2+b^2}}\\\\cos x+\\\\frac{b}{\\\\sqrt{a^2+b^2}}\\\\sin x\\\\right)$. This allows us to rewrite the expression as:\\n\\n$3\\\\cos x+4\\\\sin x+8=\\\\sqrt{3^2+4^2}\\\\left(\\\\frac{3}{\\\\sqrt{3^2+4^2}}\\\\cos x+\\\\frac{4}{\\\\sqrt{3^2+4^2}}\\\\sin x\\\\right)+8=5\\\\left(\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x\\\\right)+8$\\n\\nNow, let\'s consider the expression $\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x$. 
Since $\\\\left(\\\\frac{3}{5}\\\\right)^2+\\\\left(\\\\frac{4}{5}\\\\right)^2=1$, we can write $\\\\frac{3}{5}=\\\\cos\\\\theta$ and $\\\\frac{4}{5}=\\\\sin\\\\theta$ for some angle $\\\\theta$. Then:\\n\\n$\\\\frac{3}{5}\\\\cos x+\\\\frac{4}{5}\\\\sin x=\\\\cos\\\\theta\\\\cos x+\\\\sin\\\\theta\\\\sin x=\\\\cos(x-\\\\theta)$\\n\\nSo, the original expression can be written as:\\n\\n$5\\\\cos(x-\\\\theta)+8$\\n\\nSince the minimum value of $\\\\cos(x-\\\\theta)$ is $-1$, the minimum value of the original expression is:\\n\\n$5(-1)+8=-5+8=3$\\n\\nTherefore, the minimum value of $3\\\\cos x + 4\\\\sin x + 8$ is $\\\\boxed{3}$.","modelConfigName":"null"},{"title":"Response 2","content":"A nice math question!\\n\\nTo find the minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$, we can use the fact that $a\\\\cos x + b\\\\sin x = \\\\sqrt{a^2 + b^2} \\\\cos(x - \\\\alpha)$, where $\\\\alpha = \\\\tan^{-1}\\\\left(\\\\frac{b}{a}\\\\right)$.\\n\\nIn this case, $a = 3$ and $b = 4$, so $\\\\alpha = \\\\tan^{-1}\\\\left(\\\\frac{4}{3}\\\\right)$.\\n\\nSo, we have:\\n\\n$$3 \\\\cos x + 4 \\\\sin x + 8 = \\\\sqrt{3^2 + 4^2} \\\\cos(x - \\\\alpha) + 8 = 5 \\\\cos(x - \\\\alpha) + 8$$\\n\\nNow, the minimum value of $\\\\cos(x - \\\\alpha)$ is $-1$, so the minimum value of $5 \\\\cos(x - \\\\alpha) + 8$ is:\\n\\n$$5(-1) + 8 = -5 + 8 = 3$$\\n\\nTherefore, the minimum value of $3 \\\\cos x + 4 \\\\sin x + 8$ is $\\\\boxed{3}$.\\n\\nLet me know if you have any questions or need further clarification!","modelConfigName":"null"}]}', }, "media_attributes": { "asset_type": "conversational", "mime_type": "application/vnd.labelbox.conversational", "labelable_ids": ["message-0"], - "message_count": 1 - } + "message_count": 1, + }, } mock_requests_get.return_value.text = json.dumps(data) mock_requests_get.return_value.content = "b" diff --git a/libs/labelbox/tests/unit/export_task/test_unit_file_converter.py b/libs/labelbox/tests/unit/export_task/test_unit_file_converter.py index 3f3af9521..81e9eb60f 100644 --- a/libs/labelbox/tests/unit/export_task/test_unit_file_converter.py +++ b/libs/labelbox/tests/unit/export_task/test_unit_file_converter.py @@ -12,7 +12,6 @@ class TestFileConverter: - def test_with_correct_ndjson(self, tmp_path, generate_random_ndjson): directory = tmp_path / "file-converter" directory.mkdir() @@ -24,8 +23,9 @@ def test_with_correct_ndjson(self, tmp_path, generate_random_ndjson): client=MagicMock(), task_id="task-id", stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), + metadata_header=_MetadataHeader( + total_size=len(file_content), total_lines=line_count + ), ), file_info=_MetadataFileInfo( offsets=Range(start=0, end=len(file_content) - 1), @@ -55,8 +55,9 @@ def test_with_no_newline_at_end(self, tmp_path, generate_random_ndjson): client=MagicMock(), task_id="task-id", stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), + metadata_header=_MetadataHeader( + total_size=len(file_content), total_lines=line_count + ), ), file_info=_MetadataFileInfo( offsets=Range(start=0, end=len(file_content) - 1), diff --git a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_line.py b/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_line.py index 1dba056fa..37c93647e 100644 --- a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_line.py +++ b/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_line.py @@ -8,7 +8,6 
@@ class TestFileRetrieverByLine: - def test_by_line_from_start(self, generate_random_ndjson, mock_response): line_count = 10 ndjson = generate_random_ndjson(line_count) @@ -19,25 +18,21 @@ def test_by_line_from_start(self, generate_random_ndjson, mock_response): return_value={ "task": { "exportFileFromLine": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, + "offsets": {"start": "0", "end": len(file_content) - 1}, + "lines": {"start": "0", "end": str(line_count - 1)}, "file": "http://some-url.com/file.ndjson", } } - }) + } + ) mock_ctx = _TaskContext( client=mock_client, task_id="task-id", stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), + metadata_header=_MetadataHeader( + total_size=len(file_content), total_lines=line_count + ), ) with patch("requests.get", return_value=mock_response(file_content)): @@ -60,25 +55,21 @@ def test_by_line_from_middle(self, generate_random_ndjson, mock_response): return_value={ "task": { "exportFileFromLine": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, + "offsets": {"start": "0", "end": len(file_content) - 1}, + "lines": {"start": "0", "end": str(line_count - 1)}, "file": "http://some-url.com/file.ndjson", } } - }) + } + ) mock_ctx = _TaskContext( client=mock_client, task_id="task-id", stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), + metadata_header=_MetadataHeader( + total_size=len(file_content), total_lines=line_count + ), ) line_start = 5 @@ -104,25 +95,21 @@ def test_by_line_from_last(self, generate_random_ndjson, mock_response): return_value={ "task": { "exportFileFromLine": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, + "offsets": {"start": "0", "end": len(file_content) - 1}, + "lines": {"start": "0", "end": str(line_count - 1)}, "file": "http://some-url.com/file.ndjson", } } - }) + } + ) mock_ctx = _TaskContext( client=mock_client, task_id="task-id", stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), + metadata_header=_MetadataHeader( + total_size=len(file_content), total_lines=line_count + ), ) line_start = 9 diff --git a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_offset.py b/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_offset.py index 07271d31c..870e03307 100644 --- a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_offset.py +++ b/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_offset.py @@ -8,7 +8,6 @@ class TestFileRetrieverByOffset: - def test_by_offset_from_start(self, generate_random_ndjson, mock_response): line_count = 10 ndjson = generate_random_ndjson(line_count) @@ -19,25 +18,21 @@ def test_by_offset_from_start(self, generate_random_ndjson, mock_response): return_value={ "task": { "exportFileFromOffset": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, + "offsets": {"start": "0", "end": len(file_content) - 1}, + "lines": {"start": "0", "end": str(line_count - 1)}, "file": "http://some-url.com/file.ndjson", } } - }) + } + ) mock_ctx = _TaskContext( client=mock_client, task_id="task-id", stream_type=StreamType.RESULT, 
- metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), + metadata_header=_MetadataHeader( + total_size=len(file_content), total_lines=line_count + ), ) with patch("requests.get", return_value=mock_response(file_content)): @@ -60,25 +55,21 @@ def test_by_offset_from_middle(self, generate_random_ndjson, mock_response): return_value={ "task": { "exportFileFromOffset": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, + "offsets": {"start": "0", "end": len(file_content) - 1}, + "lines": {"start": "0", "end": str(line_count - 1)}, "file": "http://some-url.com/file.ndjson", } } - }) + } + ) mock_ctx = _TaskContext( client=mock_client, task_id="task-id", stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), + metadata_header=_MetadataHeader( + total_size=len(file_content), total_lines=line_count + ), ) line_start = 5 diff --git a/libs/labelbox/tests/unit/export_task/test_unit_json_converter.py b/libs/labelbox/tests/unit/export_task/test_unit_json_converter.py index 249eff0f5..f5ccf26fb 100644 --- a/libs/labelbox/tests/unit/export_task/test_unit_json_converter.py +++ b/libs/labelbox/tests/unit/export_task/test_unit_json_converter.py @@ -1,10 +1,14 @@ from unittest.mock import MagicMock -from labelbox.schema.export_task import Converter, JsonConverter, Range, _MetadataFileInfo +from labelbox.schema.export_task import ( + Converter, + JsonConverter, + Range, + _MetadataFileInfo, +) class TestJsonConverter: - def test_with_correct_ndjson(self, generate_random_ndjson): line_count = 10 ndjson = generate_random_ndjson(line_count) @@ -71,8 +75,9 @@ def test_from_offset(self, generate_random_ndjson): for idx, output in enumerate(converter.convert(input_args)): assert output.current_line == line_start + idx assert output.current_offset == current_offset - assert output.json_str == ndjson[line_start + - idx][skipped_bytes:] + assert ( + output.json_str == ndjson[line_start + idx][skipped_bytes:] + ) current_offset += len(output.json_str) + 1 skipped_bytes = 0 @@ -100,7 +105,8 @@ def test_from_offset_last_line(self, generate_random_ndjson): for idx, output in enumerate(converter.convert(input_args)): assert output.current_line == line_start + idx assert output.current_offset == current_offset - assert output.json_str == ndjson[line_start + - idx][skipped_bytes:] + assert ( + output.json_str == ndjson[line_start + idx][skipped_bytes:] + ) current_offset += len(output.json_str) + 1 skipped_bytes = 0 diff --git a/libs/labelbox/tests/unit/schema/test_user_group.py b/libs/labelbox/tests/unit/schema/test_user_group.py index 0eb0381d6..65584f8ef 100644 --- a/libs/labelbox/tests/unit/schema/test_user_group.py +++ b/libs/labelbox/tests/unit/schema/test_user_group.py @@ -2,7 +2,13 @@ from collections import defaultdict from unittest.mock import MagicMock from labelbox import Client -from labelbox.exceptions import ResourceConflict, ResourceCreationError, ResourceNotFoundError, MalformedQueryException, UnprocessableEntityError +from labelbox.exceptions import ( + ResourceConflict, + ResourceCreationError, + ResourceNotFoundError, + MalformedQueryException, + UnprocessableEntityError, +) from labelbox.schema.project import Project from labelbox.schema.user import User from labelbox.schema.user_group import UserGroup, UserGroupColor @@ -10,6 +16,7 @@ from labelbox.schema.ontology_kind import EditorTaskType from labelbox.schema.media_type 
import MediaType + @pytest.fixture def group_user(): user_values = defaultdict(lambda: None) @@ -30,7 +37,6 @@ def group_project(): class TestUserGroupColor: - def test_user_group_color_values(self): assert UserGroupColor.BLUE.value == "9EC5FF" assert UserGroupColor.PURPLE.value == "CEB8FF" @@ -44,12 +50,11 @@ def test_user_group_color_values(self): class TestUserGroup: - def setup_method(self): self.client = MagicMock(Client) self.client.enable_experimental = True self.group = UserGroup(client=self.client) - + def test_constructor_experimental_needed(self): client = MagicMock(Client) client.enable_experimental = False @@ -74,36 +79,20 @@ def test_update_with_exception_name(self): def test_get(self): projects = [ - { - "id": "project_id_1", - "name": "project_1" - }, - { - "id": "project_id_2", - "name": "project_2" - } + {"id": "project_id_1", "name": "project_1"}, + {"id": "project_id_2", "name": "project_2"}, ] group_members = [ - { - "id": "user_id_1", - "email": "email_1" - }, - { - "id": "user_id_2", - "email": "email_2" - } + {"id": "user_id_1", "email": "email_1"}, + {"id": "user_id_2", "email": "email_2"}, ] self.client.execute.return_value = { "userGroup": { "id": "group_id", "name": "Test Group", "color": "4ED2F9", - "projects": { - "nodes": projects - }, - "members": { - "nodes": group_members - } + "projects": {"nodes": projects}, + "members": {"nodes": group_members}, } } group = UserGroup(self.client) @@ -135,8 +124,8 @@ def test_update(self, group_user, group_project): group.id = "group_id" group.name = "Test Group" group.color = UserGroupColor.BLUE - group.users = { group_user } - group.projects = { group_project } + group.users = {group_user} + group.projects = {group_project} updated_group = group.update() @@ -209,15 +198,11 @@ def test_create(self, group_user, group_project): group = self.group group.name = "New Group" group.color = UserGroupColor.PINK - group.users = { group_user } - group.projects = { group_project } + group.users = {group_user} + group.projects = {group_project} self.client.execute.return_value = { - "createUserGroup": { - "group": { - "id": "group_id" - } - } + "createUserGroup": {"group": {"id": "group_id"}} } created_group = group.create() execute = self.client.execute.call_args[0] @@ -237,7 +222,7 @@ def test_create(self, group_user, group_project): assert list(created_group.users)[0].uid == "user_id" assert len(created_group.projects) == 1 assert list(created_group.projects)[0].uid == "project_id" - + def test_create_resource_creation_error(self): self.client.execute.side_effect = ResourceConflict("Error") group = UserGroup(self.client) @@ -251,9 +236,7 @@ def test_delete(self): group.id = "group_id" self.client.execute.return_value = { - "deleteUserGroup": { - "success": True - } + "deleteUserGroup": {"success": True} } deleted = group.delete() execute = self.client.execute.call_args[0] @@ -287,75 +270,78 @@ def test_user_groups_empty(self): def test_user_groups(self): self.client.execute.return_value = { "userGroups": { - "nextCursor": - None, - "nodes": [{ - "id": "group_id_1", - "name": "Group 1", - "color": "9EC5FF", - "projects": { - "nodes": [{ - "id": "project_id_1", - "name": "Project 1" - }, { - "id": "project_id_2", - "name": "Project 2" - }] + "nextCursor": None, + "nodes": [ + { + "id": "group_id_1", + "name": "Group 1", + "color": "9EC5FF", + "projects": { + "nodes": [ + {"id": "project_id_1", "name": "Project 1"}, + {"id": "project_id_2", "name": "Project 2"}, + ] + }, + "members": { + "nodes": [ + { + "id": "user_id_1", + 
"email": "user1@example.com", + }, + { + "id": "user_id_2", + "email": "user2@example.com", + }, + ] + }, }, - "members": { - "nodes": [{ - "id": "user_id_1", - "email": "user1@example.com" - }, { - "id": "user_id_2", - "email": "user2@example.com" - }] - } - }, { - "id": "group_id_2", - "name": "Group 2", - "color": "9EC5FF", - "projects": { - "nodes": [{ - "id": "project_id_3", - "name": "Project 3" - }, { - "id": "project_id_4", - "name": "Project 4" - }] + { + "id": "group_id_2", + "name": "Group 2", + "color": "9EC5FF", + "projects": { + "nodes": [ + {"id": "project_id_3", "name": "Project 3"}, + {"id": "project_id_4", "name": "Project 4"}, + ] + }, + "members": { + "nodes": [ + { + "id": "user_id_3", + "email": "user3@example.com", + }, + { + "id": "user_id_4", + "email": "user4@example.com", + }, + ] + }, }, - "members": { - "nodes": [{ - "id": "user_id_3", - "email": "user3@example.com" - }, { - "id": "user_id_4", - "email": "user4@example.com" - }] - } - }, { - "id": "group_id_3", - "name": "Group 3", - "color": "9EC5FF", - "projects": { - "nodes": [{ - "id": "project_id_5", - "name": "Project 5" - }, { - "id": "project_id_6", - "name": "Project 6" - }] + { + "id": "group_id_3", + "name": "Group 3", + "color": "9EC5FF", + "projects": { + "nodes": [ + {"id": "project_id_5", "name": "Project 5"}, + {"id": "project_id_6", "name": "Project 6"}, + ] + }, + "members": { + "nodes": [ + { + "id": "user_id_5", + "email": "user5@example.com", + }, + { + "id": "user_id_6", + "email": "user6@example.com", + }, + ] + }, }, - "members": { - "nodes": [{ - "id": "user_id_5", - "email": "user5@example.com" - }, { - "id": "user_id_6", - "email": "user6@example.com" - }] - } - }] + ], } } @@ -389,4 +375,5 @@ def test_user_groups(self): if __name__ == "__main__": import subprocess + subprocess.call(["pytest", "-v", __file__]) diff --git a/libs/labelbox/tests/unit/test_annotation_import.py b/libs/labelbox/tests/unit/test_annotation_import.py index ff0835467..d4642f17b 100644 --- a/libs/labelbox/tests/unit/test_annotation_import.py +++ b/libs/labelbox/tests/unit/test_annotation_import.py @@ -10,69 +10,59 @@ def test_data_row_validation_errors(): "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", }, "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, + "dataRow": {"globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"}, }, { "answer": { "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", }, "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, + "dataRow": {"globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"}, }, { "answer": { "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", }, "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, + "dataRow": {"globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"}, }, { "answer": { "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", }, "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, + "dataRow": {"globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"}, }, { "answer": { "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", }, "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, + "dataRow": {"globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"}, }, ] # Set up data for validation errors # Invalid: Remove 'dataRow' part entirely - del predictions[0]['dataRow'] + del predictions[0]["dataRow"] # Invalid: Set both id and globalKey - predictions[1]['dataRow'] = { - 'id': 'some id', - 'globalKey': 'some global key' + 
predictions[1]["dataRow"] = { + "id": "some id", + "globalKey": "some global key", } # Invalid: Set both id and globalKey to None - predictions[2]['dataRow'] = {'id': None, 'globalKey': None} + predictions[2]["dataRow"] = {"id": None, "globalKey": None} # Valid - predictions[3]['dataRow'] = { - 'id': 'some id', + predictions[3]["dataRow"] = { + "id": "some id", } # Valid - predictions[4]['dataRow'] = { - 'globalKey': 'some global key', + predictions[4]["dataRow"] = { + "globalKey": "some global key", } with pytest.raises(ValueError) as exc_info: @@ -80,6 +70,12 @@ def test_data_row_validation_errors(): exception_str = str(exc_info.value) assert "Found 3 annotations with errors" in exception_str assert "'dataRow' is missing in" in exception_str - assert "Must provide only one of 'id' or 'globalKey' for 'dataRow'" in exception_str - assert "'dataRow': {'id': 'some id', 'globalKey': 'some global key'}" in exception_str + assert ( + "Must provide only one of 'id' or 'globalKey' for 'dataRow'" + in exception_str + ) + assert ( + "'dataRow': {'id': 'some id', 'globalKey': 'some global key'}" + in exception_str + ) assert "'dataRow': {'id': None, 'globalKey': None}" in exception_str diff --git a/libs/labelbox/tests/unit/test_data_row_upsert_data.py b/libs/labelbox/tests/unit/test_data_row_upsert_data.py index b8c68c0af..11cc4153f 100644 --- a/libs/labelbox/tests/unit/test_data_row_upsert_data.py +++ b/libs/labelbox/tests/unit/test_data_row_upsert_data.py @@ -1,32 +1,37 @@ from unittest.mock import MagicMock, patch import pytest -from labelbox.schema.internal.data_row_upsert_item import (DataRowUpsertItem, - DataRowCreateItem) +from labelbox.schema.internal.data_row_upsert_item import ( + DataRowUpsertItem, + DataRowCreateItem, +) from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.schema.asset_attachment import AttachmentType from labelbox.schema.dataset import Dataset -from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator +from labelbox.schema.internal.descriptor_file_creator import ( + DescriptorFileCreator, +) from labelbox.schema.data_row import DataRow @pytest.fixture def data_row_create_items(): - dataset_id = 'test_dataset' + dataset_id = "test_dataset" items = [ { "row_data": "http://my_site.com/photos/img_01.jpg", "global_key": "global_key1", "external_id": "ex_id1", - "attachments": [{ - "type": AttachmentType.RAW_TEXT, - "name": "att1", - "value": "test1" - }], - "metadata": [{ - "name": "tag", - "value": "tag value" - },] + "attachments": [ + { + "type": AttachmentType.RAW_TEXT, + "name": "att1", + "value": "test1", + } + ], + "metadata": [ + {"name": "tag", "value": "tag value"}, + ], }, ] return dataset_id, items @@ -34,7 +39,7 @@ def data_row_create_items(): @pytest.fixture def data_row_create_items_row_data_none(): - dataset_id = 'test_dataset' + dataset_id = "test_dataset" items = [ { "row_data": None, @@ -45,16 +50,10 @@ def data_row_create_items_row_data_none(): @pytest.fixture def data_row_update_items(): - dataset_id = 'test_dataset' + dataset_id = "test_dataset" items = [ - { - "key": GlobalKey("global_key1"), - "global_key": "global_key1_updated" - }, - { - "key": UniqueId('unique_id1'), - "external_id": "ex_id1_updated" - }, + {"key": GlobalKey("global_key1"), "global_key": "global_key1_updated"}, + {"key": UniqueId("unique_id1"), "external_id": "ex_id1_updated"}, ] return dataset_id, items @@ -84,22 +83,20 @@ def test_data_row_create_items_not_updateable(data_row_update_items): def test_upsert_is_empty(): - item = 
DataRowUpsertItem(id={ - "id": UniqueId, - "value": UniqueId("123") - }, - payload={}) + item = DataRowUpsertItem( + id={"id": UniqueId, "value": UniqueId("123")}, payload={} + ) assert item.is_empty() - item = DataRowUpsertItem(id={ - "id": UniqueId, - "value": UniqueId("123") - }, - payload={"dataset_id": "test_dataset"}) + item = DataRowUpsertItem( + id={"id": UniqueId, "value": UniqueId("123")}, + payload={"dataset_id": "test_dataset"}, + ) assert item.is_empty() item = DataRowUpsertItem( - id={}, payload={"row_data": "http://my_site.com/photos/img_01.jpg"}) + id={}, payload={"row_data": "http://my_site.com/photos/img_01.jpg"} + ) assert not item.is_empty() @@ -117,29 +114,26 @@ def test_create_is_empty(): assert item.is_empty() item = DataRowCreateItem( - id={}, payload={"row_data": "http://my_site.com/photos/img_01.jpg"}) + id={}, payload={"row_data": "http://my_site.com/photos/img_01.jpg"} + ) assert not item.is_empty() item = DataRowCreateItem( id={}, - payload={DataRow.row_data: "http://my_site.com/photos/img_01.jpg"}) + payload={DataRow.row_data: "http://my_site.com/photos/img_01.jpg"}, + ) assert not item.is_empty() legacy_converstational_data_payload = { - "externalId": - "Convo-123", - "type": - "application/vnd.labelbox.conversational", - "conversationalData": [{ - "messageId": - "message-0", - "content": - "I love iphone! i just bought new iphone! :smiling_face_with_3_hearts: :calling:", - "user": { - "userId": "Bot 002", - "name": "Bot" - }, - }] + "externalId": "Convo-123", + "type": "application/vnd.labelbox.conversational", + "conversationalData": [ + { + "messageId": "message-0", + "content": "I love iphone! i just bought new iphone! :smiling_face_with_3_hearts: :calling:", + "user": {"userId": "Bot 002", "name": "Bot"}, + } + ], } item = DataRowCreateItem(id={}, payload=legacy_converstational_data_payload) assert not item.is_empty() @@ -154,20 +148,25 @@ def test_create_row_data_none(): ] client = MagicMock() dataset = Dataset( - client, { - "id": 'test_dataset', - "name": 'test_dataset', + client, + { + "id": "test_dataset", + "name": "test_dataset", "createdAt": "2021-06-01T00:00:00.000Z", "description": "test_dataset", "updatedAt": "2021-06-01T00:00:00.000Z", "rowCount": 0, - }) - - with patch.object(DescriptorFileCreator, - 'create', - return_value=["http://bar.com/chunk_uri"]): - with pytest.raises(ValueError, - match="Some items have an empty payload"): + }, + ) + + with patch.object( + DescriptorFileCreator, + "create", + return_value=["http://bar.com/chunk_uri"], + ): + with pytest.raises( + ValueError, match="Some items have an empty payload" + ): dataset.create_data_rows(items) client.execute.assert_not_called() diff --git a/libs/labelbox/tests/unit/test_exceptions.py b/libs/labelbox/tests/unit/test_exceptions.py index 69bcfbd77..4602fb984 100644 --- a/libs/labelbox/tests/unit/test_exceptions.py +++ b/libs/labelbox/tests/unit/test_exceptions.py @@ -3,11 +3,18 @@ from labelbox.exceptions import error_message_for_unparsed_graphql_error -@pytest.mark.parametrize('exception_message, expected_result', [ - ("Unparsed errors on query execution: [{'message': 'Cannot create model config for project because model setup is complete'}]", - "Cannot create model config for project because model setup is complete"), - ("blah blah blah", "Unknown error"), -]) +@pytest.mark.parametrize( + "exception_message, expected_result", + [ + ( + "Unparsed errors on query execution: [{'message': 'Cannot create model config for project because model setup is complete'}]", + "Cannot 
create model config for project because model setup is complete", + ), + ("blah blah blah", "Unknown error"), + ], +) def test_client_unparsed_exception_messages(exception_message, expected_result): - assert error_message_for_unparsed_graphql_error( - exception_message) == expected_result + assert ( + error_message_for_unparsed_graphql_error(exception_message) + == expected_result + ) diff --git a/libs/labelbox/tests/unit/test_label_data_type.py b/libs/labelbox/tests/unit/test_label_data_type.py index 737136a36..7bc32e37c 100644 --- a/libs/labelbox/tests/unit/test_label_data_type.py +++ b/libs/labelbox/tests/unit/test_label_data_type.py @@ -2,35 +2,36 @@ import pytest from pydantic import ValidationError -from labelbox.data.annotation_types.data.generic_data_row_data import GenericDataRowData +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) from labelbox.data.annotation_types.data.video import VideoData from labelbox.data.annotation_types.label import Label def test_generic_data_type(): data = { - 'global_key': - 'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr', + "global_key": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr", } label = Label(data=data) data = label.data assert isinstance(data, GenericDataRowData) - assert data.global_key == 'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr' + assert ( + data.global_key + == "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr" + ) def test_generic_data_type_validations(): data = { - 'row_data': - 'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr', + "row_data": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr", } with pytest.raises(ValueError, match="Exactly one of"): Label(data=data) data = { - 'uid': - "abcd", - 'global_key': - 'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr', + "uid": "abcd", + "global_key": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr", } with pytest.raises(ValueError, match="Only one of"): Label(data=data) @@ -38,22 +39,26 @@ def test_generic_data_type_validations(): def test_video_data_type(): data = { - 'global_key': - 'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr', + "global_key": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr", } with pytest.warns(UserWarning, match="Use a dict"): label = Label(data=VideoData(**data)) data = label.data assert isinstance(data, VideoData) - assert data.global_key == 'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr' + assert ( + data.global_key + == "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr" + ) def test_generic_data_row(): data = { - 'global_key': - 'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr', + "global_key": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr", } label = Label(data=GenericDataRowData(**data)) data = label.data assert isinstance(data, GenericDataRowData) - assert data.global_key == 
'https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr' + assert ( + data.global_key + == "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr" + ) diff --git a/libs/labelbox/tests/unit/test_mal_import.py b/libs/labelbox/tests/unit/test_mal_import.py index 799944a13..3dc3eea56 100644 --- a/libs/labelbox/tests/unit/test_mal_import.py +++ b/libs/labelbox/tests/unit/test_mal_import.py @@ -11,35 +11,32 @@ def test_should_warn_user_about_unsupported_confidence(): labels = [ { - "bbox": { - "height": 428, - "left": 2089, - "top": 1251, - "width": 158 - }, - "classifications": [{ - "answer": [{ - "schemaId": "ckrb1sfl8099e0y919v260awv", - "confidence": 0.894 - }], - "schemaId": "ckrb1sfkn099c0y910wbo0p1a" - }], - "dataRow": { - "id": "ckrb1sf1i1g7i0ybcdc6oc8ct" - }, + "bbox": {"height": 428, "left": 2089, "top": 1251, "width": 158}, + "classifications": [ + { + "answer": [ + { + "schemaId": "ckrb1sfl8099e0y919v260awv", + "confidence": 0.894, + } + ], + "schemaId": "ckrb1sfkn099c0y910wbo0p1a", + } + ], + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, "schemaId": "ckrb1sfjx099a0y914hl319ie", - "uuid": "d009925d-91a3-4f67-abd9-753453f5a584" + "uuid": "d009925d-91a3-4f67-abd9-753453f5a584", }, ] - with patch.object(MALPredictionImport, '_create_mal_import_from_bytes'): - with patch.object(logger, 'warning') as warning_mock: - MALPredictionImport.create_from_objects(client=MagicMock(), - project_id=id, - name=id, - predictions=labels) + with patch.object(MALPredictionImport, "_create_mal_import_from_bytes"): + with patch.object(logger, "warning") as warning_mock: + MALPredictionImport.create_from_objects( + client=MagicMock(), project_id=id, name=id, predictions=labels + ) warning_mock.assert_called_once() "Confidence scores are not supported in MAL Prediction Import" in warning_mock.call_args_list[ - 0].args[0] + 0 + ].args[0] def test_invalid_labels_format(): @@ -47,29 +44,25 @@ def test_invalid_labels_format(): id = str(uuid.uuid4()) label = { - "bbox": { - "height": 428, - "left": 2089, - "top": 1251, - "width": 158 - }, - "classifications": [{ - "answer": [{ - "schemaId": "ckrb1sfl8099e0y919v260awv", - "confidence": 0.894 - }], - "schemaId": "ckrb1sfkn099c0y910wbo0p1a" - }], - "dataRow": { - "id": "ckrb1sf1i1g7i0ybcdc6oc8ct" - }, + "bbox": {"height": 428, "left": 2089, "top": 1251, "width": 158}, + "classifications": [ + { + "answer": [ + { + "schemaId": "ckrb1sfl8099e0y919v260awv", + "confidence": 0.894, + } + ], + "schemaId": "ckrb1sfkn099c0y910wbo0p1a", + } + ], + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, "schemaId": "ckrb1sfjx099a0y914hl319ie", - "uuid": "3a83db52-75e0-49af-a171-234ce604502a" + "uuid": "3a83db52-75e0-49af-a171-234ce604502a", } - with patch.object(MALPredictionImport, '_create_mal_import_from_bytes'): + with patch.object(MALPredictionImport, "_create_mal_import_from_bytes"): with pytest.raises(TypeError): - MALPredictionImport.create_from_objects(client=MagicMock(), - project_id=id, - name=id, - predictions=label) + MALPredictionImport.create_from_objects( + client=MagicMock(), project_id=id, name=id, predictions=label + ) diff --git a/libs/labelbox/tests/unit/test_ndjson_parsing.py b/libs/labelbox/tests/unit/test_ndjson_parsing.py index 832e41928..508e44d74 100644 --- a/libs/labelbox/tests/unit/test_ndjson_parsing.py +++ b/libs/labelbox/tests/unit/test_ndjson_parsing.py @@ -15,7 +15,7 @@ def test_loads(ndjson_content): def test_loads_bytes(ndjson_content): expected_line, 
expected_objects = ndjson_content - bytes_line = expected_line.encode('utf-8') + bytes_line = expected_line.encode("utf-8") parsed_line = parser.loads(bytes_line) assert parsed_line == expected_objects diff --git a/libs/labelbox/tests/unit/test_project.py b/libs/labelbox/tests/unit/test_project.py index 367f74296..5e5f99c57 100644 --- a/libs/labelbox/tests/unit/test_project.py +++ b/libs/labelbox/tests/unit/test_project.py @@ -32,15 +32,21 @@ def project_entity(): @pytest.mark.parametrize( - 'api_editor_task_type, expected_editor_task_type', - [(None, EditorTaskType.Missing), - ('MODEL_CHAT_EVALUATION', EditorTaskType.ModelChatEvaluation), - ('RESPONSE_CREATION', EditorTaskType.ResponseCreation), - ('OFFLINE_MODEL_CHAT_EVALUATION', - EditorTaskType.OfflineModelChatEvaluation), - ('NEW_TYPE', EditorTaskType.Missing)]) -def test_project_editor_task_type(api_editor_task_type, - expected_editor_task_type, project_entity): + "api_editor_task_type, expected_editor_task_type", + [ + (None, EditorTaskType.Missing), + ("MODEL_CHAT_EVALUATION", EditorTaskType.ModelChatEvaluation), + ("RESPONSE_CREATION", EditorTaskType.ResponseCreation), + ( + "OFFLINE_MODEL_CHAT_EVALUATION", + EditorTaskType.OfflineModelChatEvaluation, + ), + ("NEW_TYPE", EditorTaskType.Missing), + ], +) +def test_project_editor_task_type( + api_editor_task_type, expected_editor_task_type, project_entity +): client = MagicMock() project = Project( client, diff --git a/libs/labelbox/tests/unit/test_unit_delete_batch_data_row_metadata.py b/libs/labelbox/tests/unit/test_unit_delete_batch_data_row_metadata.py index 561f8d6b0..cd6eadd79 100644 --- a/libs/labelbox/tests/unit/test_unit_delete_batch_data_row_metadata.py +++ b/libs/labelbox/tests/unit/test_unit_delete_batch_data_row_metadata.py @@ -7,50 +7,44 @@ def test_dict_delete_data_row_batch(): obj = _DeleteBatchDataRowMetadata( data_row_identifier=UniqueId("abcd"), - schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"]) + schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"], + ) assert obj.model_dump() == { - "data_row_identifier": { - "id": "abcd", - "id_type": "ID" - }, + "data_row_identifier": {"id": "abcd", "id_type": "ID"}, "schema_ids": [ - "clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy" - ] + "clqh77tyk000008l2a9mjesa1", + "clqh784br000008jy0yuq04fy", + ], } obj = _DeleteBatchDataRowMetadata( data_row_identifier=GlobalKey("fegh"), - schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"]) + schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"], + ) assert obj.model_dump() == { - "data_row_identifier": { - "id": "fegh", - "id_type": "GKEY" - }, + "data_row_identifier": {"id": "fegh", "id_type": "GKEY"}, "schema_ids": [ - "clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy" - ] + "clqh77tyk000008l2a9mjesa1", + "clqh784br000008jy0yuq04fy", + ], } def test_dict_delete_data_row_batch_by_alias(): obj = _DeleteBatchDataRowMetadata( data_row_identifier=UniqueId("abcd"), - schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"]) + schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"], + ) assert obj.model_dump(by_alias=True) == { - "dataRowIdentifier": { - "id": "abcd", - "idType": "ID" - }, - "schemaIds": ["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"] + "dataRowIdentifier": {"id": "abcd", "idType": "ID"}, + "schemaIds": ["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"], } obj = _DeleteBatchDataRowMetadata( data_row_identifier=GlobalKey("fegh"), - 
schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"]) + schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"], + ) assert obj.model_dump(by_alias=True) == { - "dataRowIdentifier": { - "id": "fegh", - "idType": "GKEY" - }, - "schemaIds": ["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"] + "dataRowIdentifier": {"id": "fegh", "idType": "GKEY"}, + "schemaIds": ["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"], } diff --git a/libs/labelbox/tests/unit/test_unit_descriptor_file_creator.py b/libs/labelbox/tests/unit/test_unit_descriptor_file_creator.py index 630d80573..621317ddd 100644 --- a/libs/labelbox/tests/unit/test_unit_descriptor_file_creator.py +++ b/libs/labelbox/tests/unit/test_unit_descriptor_file_creator.py @@ -3,7 +3,9 @@ from unittest.mock import MagicMock, Mock import pytest -from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator +from labelbox.schema.internal.descriptor_file_creator import ( + DescriptorFileCreator, +) def test_chunk_down_by_bytes_row_too_large(): @@ -14,8 +16,9 @@ def test_chunk_down_by_bytes_row_too_large(): chunk = [{"row_data": "a"}] max_chunk_size_bytes = 1 - res = descriptor_file_creator._chunk_down_by_bytes(chunk, - max_chunk_size_bytes) + res = descriptor_file_creator._chunk_down_by_bytes( + chunk, max_chunk_size_bytes + ) assert [x for x in res] == [json.dumps([{"row_data": "a"}])] @@ -27,14 +30,12 @@ def test_chunk_down_by_bytes_more_chunks(): chunk = [{"row_data": "a"}, {"row_data": "b"}] max_chunk_size_bytes = len(json.dumps(chunk).encode("utf-8")) - 1 - res = descriptor_file_creator._chunk_down_by_bytes(chunk, - max_chunk_size_bytes) + res = descriptor_file_creator._chunk_down_by_bytes( + chunk, max_chunk_size_bytes + ) assert [x for x in res] == [ - json.dumps([{ - "row_data": "a" - }]), json.dumps([{ - "row_data": "b" - }]) + json.dumps([{"row_data": "a"}]), + json.dumps([{"row_data": "b"}]), ] @@ -46,11 +47,9 @@ def test_chunk_down_by_bytes_one_chunk(): chunk = [{"row_data": "a"}, {"row_data": "b"}] max_chunk_size_bytes = len(json.dumps(chunk).encode("utf-8")) - res = descriptor_file_creator._chunk_down_by_bytes(chunk, - max_chunk_size_bytes) - assert [x for x in res - ] == [json.dumps([{ - "row_data": "a" - }, { - "row_data": "b" - }])] + res = descriptor_file_creator._chunk_down_by_bytes( + chunk, max_chunk_size_bytes + ) + assert [x for x in res] == [ + json.dumps([{"row_data": "a"}, {"row_data": "b"}]) + ] diff --git a/libs/labelbox/tests/unit/test_unit_entity_meta.py b/libs/labelbox/tests/unit/test_unit_entity_meta.py index d24f985d9..06278951b 100644 --- a/libs/labelbox/tests/unit/test_unit_entity_meta.py +++ b/libs/labelbox/tests/unit/test_unit_entity_meta.py @@ -5,7 +5,6 @@ def test_illegal_cache_cond1(): - class TestEntityA(DbObject): test_entity_b = Relationship.ToOne("TestEntityB", cache=True) @@ -14,12 +13,13 @@ class TestEntityA(DbObject): class TestEntityB(DbObject): another_entity = Relationship.ToOne("AnotherEntity", cache=True) - assert "`test_entity_a` caches `test_entity_b` which caches `['another_entity']`" in str( - exc_info.value) + assert ( + "`test_entity_a` caches `test_entity_b` which caches `['another_entity']`" + in str(exc_info.value) + ) def test_illegal_cache_cond2(): - class TestEntityD(DbObject): another_entity = Relationship.ToOne("AnotherEntity", cache=True) @@ -28,5 +28,7 @@ class TestEntityD(DbObject): class TestEntityC(DbObject): test_entity_d = Relationship.ToOne("TestEntityD", cache=True) - assert "`test_entity_c` caches 
`test_entity_d` which caches `['another_entity']`" in str( - exc_info.value) + assert ( + "`test_entity_c` caches `test_entity_d` which caches `['another_entity']`" + in str(exc_info.value) + ) diff --git a/libs/labelbox/tests/unit/test_unit_export_filters.py b/libs/labelbox/tests/unit/test_unit_export_filters.py index 5986ae44e..3be78152e 100644 --- a/libs/labelbox/tests/unit/test_unit_export_filters.py +++ b/libs/labelbox/tests/unit/test_unit_export_filters.py @@ -8,33 +8,39 @@ def test_ids_filter(): client = MagicMock() filters = {"data_row_ids": ["id1", "id2"], "batch_ids": ["b1", "b2"]} - assert build_filters(client, filters) == [{ - "ids": ["id1", "id2"], - "operator": "is", - "type": "data_row_id", - }, { - "ids": ["b1", "b2"], - "operator": "is", - "type": "batch", - }] + assert build_filters(client, filters) == [ + { + "ids": ["id1", "id2"], + "operator": "is", + "type": "data_row_id", + }, + { + "ids": ["b1", "b2"], + "operator": "is", + "type": "batch", + }, + ] def test_ids_empty_filter(): client = MagicMock() filters = {"data_row_ids": [], "batch_ids": ["b1", "b2"]} - with pytest.raises(ValueError, - match="data_row_id filter expects a non-empty list."): + with pytest.raises( + ValueError, match="data_row_id filter expects a non-empty list." + ): build_filters(client, filters) def test_global_keys_filter(): client = MagicMock() filters = {"global_keys": ["id1", "id2"]} - assert build_filters(client, filters) == [{ - "ids": ["id1", "id2"], - "operator": "is", - "type": "global_key", - }] + assert build_filters(client, filters) == [ + { + "ids": ["id1", "id2"], + "operator": "is", + "type": "global_key", + } + ] def test_validations(): @@ -44,8 +50,7 @@ def test_validations(): "data_row_ids": ["id1", "id2"], } with pytest.raises( - ValueError, - match= - "data_rows and global_keys cannot both be present in export filters" + ValueError, + match="data_rows and global_keys cannot both be present in export filters", ): build_filters(client, filters) diff --git a/libs/labelbox/tests/unit/test_unit_label_import.py b/libs/labelbox/tests/unit/test_unit_label_import.py index feff4694c..b386a664d 100644 --- a/libs/labelbox/tests/unit/test_unit_label_import.py +++ b/libs/labelbox/tests/unit/test_unit_label_import.py @@ -13,27 +13,20 @@ def test_should_warn_user_about_unsupported_confidence(): { "uuid": "b862c586-8614-483c-b5e6-82810f70cac0", "schemaId": "ckrazcueb16og0z6609jj7y3y", - "dataRow": { - "id": "ckrazctum0z8a0ybc0b0o0g0v" - }, + "dataRow": {"id": "ckrazctum0z8a0ybc0b0o0g0v"}, "confidence": 0.851, - "bbox": { - "top": 1352, - "left": 2275, - "height": 350, - "width": 139 - } + "bbox": {"top": 1352, "left": 2275, "height": 350, "width": 139}, }, ] - with patch.object(LabelImport, '_create_label_import_from_bytes'): - with patch.object(logger, 'warning') as warning_mock: - LabelImport.create_from_objects(client=MagicMock(), - project_id=id, - name=id, - labels=labels) + with patch.object(LabelImport, "_create_label_import_from_bytes"): + with patch.object(logger, "warning") as warning_mock: + LabelImport.create_from_objects( + client=MagicMock(), project_id=id, name=id, labels=labels + ) warning_mock.assert_called_once() "Confidence scores are not supported in Label Import" in warning_mock.call_args_list[ - 0].args[0] + 0 + ].args[0] def test_invalid_labels_format(): @@ -43,19 +36,11 @@ def test_invalid_labels_format(): label = { "uuid": "b862c586-8614-483c-b5e6-82810f70cac0", "schemaId": "ckrazcueb16og0z6609jj7y3y", - "dataRow": { - "id": "ckrazctum0z8a0ybc0b0o0g0v" - }, - "bbox": 
{ - "top": 1352, - "left": 2275, - "height": 350, - "width": 139 - } + "dataRow": {"id": "ckrazctum0z8a0ybc0b0o0g0v"}, + "bbox": {"top": 1352, "left": 2275, "height": 350, "width": 139}, } - with patch.object(LabelImport, '_create_label_import_from_bytes'): + with patch.object(LabelImport, "_create_label_import_from_bytes"): with pytest.raises(TypeError): - LabelImport.create_from_objects(client=MagicMock(), - project_id=id, - name=id, - labels=label) + LabelImport.create_from_objects( + client=MagicMock(), project_id=id, name=id, labels=label + ) diff --git a/libs/labelbox/tests/unit/test_unit_ontology.py b/libs/labelbox/tests/unit/test_unit_ontology.py index ac53827c6..0566ad623 100644 --- a/libs/labelbox/tests/unit/test_unit_ontology.py +++ b/libs/labelbox/tests/unit/test_unit_ontology.py @@ -5,183 +5,187 @@ from itertools import product _SAMPLE_ONTOLOGY = { - "tools": [{ - "schemaNodeId": None, - "featureSchemaId": None, - "required": False, - "name": "poly", - "color": "#FF0000", - "tool": "polygon", - "classifications": [] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "required": False, - "name": "segment", - "color": "#FF0000", - "tool": "superpixel", - "classifications": [] - }, { - "schemaNodeId": - None, - "featureSchemaId": - None, - "required": - False, - "name": - "bbox", - "color": - "#FF0000", - "tool": - "rectangle", - "classifications": [{ - "schemaNodeId": - None, - "featureSchemaId": - None, - "required": - True, - "instructions": - "nested classification", - "name": - "nested classification", - "type": - "radio", - 'uiMode': - "searchable", - "options": [{ - "schemaNodeId": - None, - "featureSchemaId": - None, - "label": - "first", - "value": - "first", - "options": [{ + "tools": [ + { + "schemaNodeId": None, + "featureSchemaId": None, + "required": False, + "name": "poly", + "color": "#FF0000", + "tool": "polygon", + "classifications": [], + }, + { + "schemaNodeId": None, + "featureSchemaId": None, + "required": False, + "name": "segment", + "color": "#FF0000", + "tool": "superpixel", + "classifications": [], + }, + { + "schemaNodeId": None, + "featureSchemaId": None, + "required": False, + "name": "bbox", + "color": "#FF0000", + "tool": "rectangle", + "classifications": [ + { + "schemaNodeId": None, + "featureSchemaId": None, + "required": True, + "instructions": "nested classification", + "name": "nested classification", + "type": "radio", + "uiMode": "searchable", + "options": [ + { + "schemaNodeId": None, + "featureSchemaId": None, + "label": "first", + "value": "first", + "options": [ + { + "schemaNodeId": None, + "featureSchemaId": None, + "required": False, + "instructions": "nested nested text", + "name": "nested nested text", + "type": "text", + "options": [], + } + ], + }, + { + "schemaNodeId": None, + "featureSchemaId": None, + "label": "second", + "value": "second", + "options": [], + }, + ], + }, + { "schemaNodeId": None, "featureSchemaId": None, - "required": False, - "instructions": "nested nested text", - "name": "nested nested text", + "required": True, + "instructions": "nested text", + "name": "nested text", "type": "text", - "options": [] - }] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "label": "second", - "value": "second", - "options": [] - }] - }, { + "options": [], + }, + ], + }, + { "schemaNodeId": None, "featureSchemaId": None, - "required": True, - "instructions": "nested text", - "name": "nested text", - "type": "text", - "options": [] - }] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "required": 
False, - "name": "dot", - "color": "#FF0000", - "tool": "point", - "classifications": [] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "required": False, - "name": "polyline", - "color": "#FF0000", - "tool": "line", - "classifications": [] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "required": False, - "name": "ner", - "color": "#FF0000", - "tool": "named-entity", - "classifications": [] - }], - "classifications": [{ - "schemaNodeId": - None, - "featureSchemaId": - None, - "required": - True, - "instructions": - "This is a question.", - "name": - "This is a question.", - "type": - "radio", - "scope": - "global", - 'uiMode': - "searchable", - "options": [{ + "required": False, + "name": "dot", + "color": "#FF0000", + "tool": "point", + "classifications": [], + }, + { "schemaNodeId": None, "featureSchemaId": None, - "label": "yes", - "value": "definitely yes", - "options": [] - }, { + "required": False, + "name": "polyline", + "color": "#FF0000", + "tool": "line", + "classifications": [], + }, + { "schemaNodeId": None, "featureSchemaId": None, - "label": "no", - "value": "definitely not", - "options": [] - }] - }] + "required": False, + "name": "ner", + "color": "#FF0000", + "tool": "named-entity", + "classifications": [], + }, + ], + "classifications": [ + { + "schemaNodeId": None, + "featureSchemaId": None, + "required": True, + "instructions": "This is a question.", + "name": "This is a question.", + "type": "radio", + "scope": "global", + "uiMode": "searchable", + "options": [ + { + "schemaNodeId": None, + "featureSchemaId": None, + "label": "yes", + "value": "definitely yes", + "options": [], + }, + { + "schemaNodeId": None, + "featureSchemaId": None, + "label": "no", + "value": "definitely not", + "options": [], + }, + ], + } + ], } @pytest.mark.parametrize("tool_type", list(Tool.Type)) def test_create_tool(tool_type) -> None: t = Tool(tool=tool_type, name="tool") - assert (t.tool == tool_type) + assert t.tool == tool_type @pytest.mark.parametrize("class_type", list(Classification.Type)) def test_create_classification(class_type) -> None: c = Classification(class_type=class_type, name="classification") - assert (c.class_type == class_type) + assert c.class_type == class_type + -@pytest.mark.parametrize("ui_mode_type, class_type", list(product(list(Classification.UIMode), list(Classification.Type)))) +@pytest.mark.parametrize( + "ui_mode_type, class_type", + list(product(list(Classification.UIMode), list(Classification.Type))), +) def test_create_classification_with_ui_mode(ui_mode_type, class_type) -> None: - c = Classification(name="classification", class_type=class_type, ui_mode=ui_mode_type) - assert (c.ui_mode == ui_mode_type) + c = Classification( + name="classification", class_type=class_type, ui_mode=ui_mode_type + ) + assert c.ui_mode == ui_mode_type -@pytest.mark.parametrize("value, expected_value, typing", - [(3, 3, int), ("string", "string", str)]) +@pytest.mark.parametrize( + "value, expected_value, typing", [(3, 3, int), ("string", "string", str)] +) def test_create_option_with_value(value, expected_value, typing) -> None: o = Option(value=value) - assert (o.value == expected_value) - assert (o.value == o.label) + assert o.value == expected_value + assert o.value == o.label -@pytest.mark.parametrize("value, label, expected_value, typing", - [(3, 2, 3, int), - ("string", "another string", "string", str)]) -def test_create_option_with_value_and_label(value, label, expected_value, - typing) -> None: +@pytest.mark.parametrize( + "value, label, 
expected_value, typing", + [(3, 2, 3, int), ("string", "another string", "string", str)], +) +def test_create_option_with_value_and_label( + value, label, expected_value, typing +) -> None: o = Option(value=value, label=label) - assert (o.value == expected_value) + assert o.value == expected_value assert o.value != o.label assert isinstance(o.value, typing) def test_create_empty_ontology() -> None: o = OntologyBuilder() - assert (o.tools == []) - assert (o.classifications == []) + assert o.tools == [] + assert o.classifications == [] def test_add_ontology_tool() -> None: @@ -193,7 +197,7 @@ def test_add_ontology_tool() -> None: assert len(o.tools) == 2 for tool in o.tools: - assert (type(tool) == Tool) + assert type(tool) == Tool with pytest.raises(InconsistentOntologyException) as exc: o.add_tool(Tool(tool=Tool.Type.BBOX, name="bounding box")) @@ -203,19 +207,22 @@ def test_add_ontology_tool() -> None: def test_add_ontology_classification() -> None: o = OntologyBuilder() o.add_classification( - Classification(class_type=Classification.Type.TEXT, name="text")) + Classification(class_type=Classification.Type.TEXT, name="text") + ) second_classification = Classification( - class_type=Classification.Type.CHECKLIST, name="checklist") + class_type=Classification.Type.CHECKLIST, name="checklist" + ) o.add_classification(second_classification) assert len(o.classifications) == 2 for classification in o.classifications: - assert (type(classification) == Classification) + assert type(classification) == Classification with pytest.raises(InconsistentOntologyException) as exc: o.add_classification( - Classification(class_type=Classification.Type.TEXT, name="text")) + Classification(class_type=Classification.Type.TEXT, name="text") + ) assert "Duplicate classification name" in str(exc.value) @@ -253,8 +260,9 @@ def test_option_add_option() -> None: def test_ontology_asdict() -> None: - assert OntologyBuilder.from_dict( - _SAMPLE_ONTOLOGY).asdict() == _SAMPLE_ONTOLOGY + assert ( + OntologyBuilder.from_dict(_SAMPLE_ONTOLOGY).asdict() == _SAMPLE_ONTOLOGY + ) def test_classification_using_instructions_instead_of_name_shows_warning(): diff --git a/libs/labelbox/tests/unit/test_unit_ontology_kind.py b/libs/labelbox/tests/unit/test_unit_ontology_kind.py index 51e2cf214..54cec0812 100644 --- a/libs/labelbox/tests/unit/test_unit_ontology_kind.py +++ b/libs/labelbox/tests/unit/test_unit_ontology_kind.py @@ -1,4 +1,8 @@ -from labelbox.schema.ontology_kind import OntologyKind, EditorTaskType, EditorTaskTypeMapper +from labelbox.schema.ontology_kind import ( + OntologyKind, + EditorTaskType, + EditorTaskTypeMapper, +) from labelbox.schema.media_type import MediaType @@ -6,17 +10,20 @@ def test_ontology_kind_conversions_from_editor_task_type(): ontology_kind = OntologyKind.ModelEvaluation media_type = MediaType.Conversational editor_task_type = EditorTaskTypeMapper.to_editor_task_type( - ontology_kind, media_type) + ontology_kind, media_type + ) assert editor_task_type == EditorTaskType.ModelChatEvaluation ontology_kind = OntologyKind.Missing media_type = MediaType.Image editor_task_type = EditorTaskTypeMapper.to_editor_task_type( - ontology_kind, media_type) + ontology_kind, media_type + ) assert editor_task_type == EditorTaskType.Missing ontology_kind = OntologyKind.ModelEvaluation media_type = MediaType.Video editor_task_type = EditorTaskTypeMapper.to_editor_task_type( - ontology_kind, media_type) + ontology_kind, media_type + ) assert editor_task_type == EditorTaskType.Missing diff --git 
a/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py b/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py index f9f9a0959..7f6d29d5a 100644 --- a/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py +++ b/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py @@ -22,8 +22,11 @@ def test_validate_labeling_parameter_overrides_invalid_data(): def test_validate_labeling_parameter_overrides_invalid_priority(): mock_data_row = MagicMock(spec=DataRow) mock_data_row.uid = "abc" - data = [(mock_data_row, "invalid"), (UniqueId("efg"), 2), - (GlobalKey("hij"), 3)] + data = [ + (mock_data_row, "invalid"), + (UniqueId("efg"), 2), + (GlobalKey("hij"), 3), + ] with pytest.raises(TypeError): validate_labeling_parameter_overrides(data) @@ -31,7 +34,10 @@ def test_validate_labeling_parameter_overrides_invalid_priority(): def test_validate_labeling_parameter_overrides_invalid_tuple_length(): mock_data_row = MagicMock(spec=DataRow) mock_data_row.uid = "abc" - data = [(mock_data_row, "invalid"), (UniqueId("efg"), 2), - (GlobalKey("hij"))] + data = [ + (mock_data_row, "invalid"), + (UniqueId("efg"), 2), + (GlobalKey("hij")), + ] with pytest.raises(TypeError): validate_labeling_parameter_overrides(data) diff --git a/libs/labelbox/tests/unit/test_unit_query.py b/libs/labelbox/tests/unit/test_unit_query.py index 12db00d2b..83bfeff8a 100644 --- a/libs/labelbox/tests/unit/test_unit_query.py +++ b/libs/labelbox/tests/unit/test_unit_query.py @@ -24,13 +24,15 @@ def test_query_where(): assert q.startswith("x(where: {name_gt: $param_0}){") assert p == {"param_0": ("name", Project.name)} - q, p = query.Query("x", Project, - (Project.name != "name") & (Project.uid <= 42)).format() + q, p = query.Query( + "x", Project, (Project.name != "name") & (Project.uid <= 42) + ).format() assert q.startswith( - "x(where: {AND: [{name_not: $param_0}, {id_lte: $param_1}]}") + "x(where: {AND: [{name_not: $param_0}, {id_lte: $param_1}]}" + ) assert p == { "param_0": ("name", Project.name), - "param_1": (42, Project.uid) + "param_1": (42, Project.uid), } @@ -38,8 +40,9 @@ def test_query_param_declaration(): q, _ = query.Query("x", Project, Project.name > "name").format_top("y") assert q.startswith("query yPyApi($param_0: String!){x") - q, _ = query.Query("x", Project, (Project.name > "name") & - (Project.uid == 42)).format_top("y") + q, _ = query.Query( + "x", Project, (Project.name > "name") & (Project.uid == 42) + ).format_top("y") assert q.startswith("query yPyApi($param_0: String!, $param_1: ID!){x") diff --git a/libs/labelbox/tests/unit/test_unit_search_filters.py b/libs/labelbox/tests/unit/test_unit_search_filters.py index eba8d4db8..b2230bb7f 100644 --- a/libs/labelbox/tests/unit/test_unit_search_filters.py +++ b/libs/labelbox/tests/unit/test_unit_search_filters.py @@ -1,37 +1,68 @@ from datetime import datetime from labelbox.schema.labeling_service import LabelingServiceStatus -from labelbox.schema.search_filters import IntegerValue, RangeDateTimeOperatorWithSingleValue, RangeOperatorWithSingleValue, DateRange, RangeOperatorWithValue, DateRangeValue, DateValue, IdOperator, OperationType, OrganizationFilter, ProjectStageFilter, SharedWithOrganizationFilter, TagFilter, TaskCompletedCountFilter, TaskRemainingCountFilter, WorkforceRequestedDateFilter, WorkforceRequestedDateRangeFilter, WorkforceStageUpdatedFilter, WorkforceStageUpdatedRangeFilter, WorkspaceFilter, build_search_filter +from 
labelbox.schema.search_filters import ( + IntegerValue, + RangeDateTimeOperatorWithSingleValue, + RangeOperatorWithSingleValue, + DateRange, + RangeOperatorWithValue, + DateRangeValue, + DateValue, + IdOperator, + OperationType, + OrganizationFilter, + ProjectStageFilter, + SharedWithOrganizationFilter, + TagFilter, + TaskCompletedCountFilter, + TaskRemainingCountFilter, + WorkforceRequestedDateFilter, + WorkforceRequestedDateRangeFilter, + WorkforceStageUpdatedFilter, + WorkforceStageUpdatedRangeFilter, + WorkspaceFilter, + build_search_filter, +) from labelbox.utils import format_iso_datetime import pytest def test_id_filters(): filters = [ - OrganizationFilter(operator=IdOperator.Is, - values=["clphb4vd7000cd2wv1ktu5cwa"]), - SharedWithOrganizationFilter(operator=IdOperator.Is, - values=["clphb4vd7000cd2wv1ktu5cwa"]), - WorkspaceFilter(operator=IdOperator.Is, - values=["clphb4vd7000cd2wv1ktu5cwa"]), + OrganizationFilter( + operator=IdOperator.Is, values=["clphb4vd7000cd2wv1ktu5cwa"] + ), + SharedWithOrganizationFilter( + operator=IdOperator.Is, values=["clphb4vd7000cd2wv1ktu5cwa"] + ), + WorkspaceFilter( + operator=IdOperator.Is, values=["clphb4vd7000cd2wv1ktu5cwa"] + ), TagFilter(operator=IdOperator.Is, values=["cls1vkrw401ab072vg2pq3t5d"]), - ProjectStageFilter(operator=IdOperator.Is, - values=[LabelingServiceStatus.Requested]), + ProjectStageFilter( + operator=IdOperator.Is, values=[LabelingServiceStatus.Requested] + ), ] - assert build_search_filter( - filters - ) == '[{type: "organization_id", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "shared_with_organizations", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "workspace", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "tag", operator: "is", values: ["cls1vkrw401ab072vg2pq3t5d"]}, {type: "stage", operator: "is", values: ["REQUESTED"]}]' + assert ( + build_search_filter(filters) + == '[{type: "organization_id", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "shared_with_organizations", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "workspace", operator: "is", values: ["clphb4vd7000cd2wv1ktu5cwa"]}, {type: "tag", operator: "is", values: ["cls1vkrw401ab072vg2pq3t5d"]}, {type: "stage", operator: "is", values: ["REQUESTED"]}]' + ) def test_stage_filter_with_invalid_values(): with pytest.raises( - ValueError, - match="is not a valid value for ProjectStageFilter") as e: - _ = ProjectStageFilter(operator=IdOperator.Is, - values=[ - LabelingServiceStatus.Requested, - LabelingServiceStatus.Missing - ]), + ValueError, match="is not a valid value for ProjectStageFilter" + ) as e: + _ = ( + ProjectStageFilter( + operator=IdOperator.Is, + values=[ + LabelingServiceStatus.Requested, + LabelingServiceStatus.Missing, + ], + ), + ) def test_date_filters(): @@ -39,46 +70,80 @@ def test_date_filters(): local_time_end = datetime.strptime("2025-01-01", "%Y-%m-%d") filters = [ - WorkforceRequestedDateFilter(value=DateValue( - operator=RangeDateTimeOperatorWithSingleValue.GreaterThanOrEqual, - value=local_time_start)), - WorkforceStageUpdatedFilter(value=DateValue( - operator=RangeDateTimeOperatorWithSingleValue.LessThanOrEqual, - value=local_time_end)), + WorkforceRequestedDateFilter( + value=DateValue( + operator=RangeDateTimeOperatorWithSingleValue.GreaterThanOrEqual, + value=local_time_start, + ) + ), + WorkforceStageUpdatedFilter( + value=DateValue( + operator=RangeDateTimeOperatorWithSingleValue.LessThanOrEqual, + value=local_time_end, + ) + ), ] expected_start = 
format_iso_datetime(local_time_start) expected_end = format_iso_datetime(local_time_end) - expected = '[{type: "workforce_requested_at", value: {operator: "GREATER_THAN_OR_EQUAL", value: "' + expected_start + '"}}, {type: "workforce_stage_updated_at", value: {operator: "LESS_THAN_OR_EQUAL", value: "' + expected_end + '"}}]' + expected = ( + '[{type: "workforce_requested_at", value: {operator: "GREATER_THAN_OR_EQUAL", value: "' + + expected_start + + '"}}, {type: "workforce_stage_updated_at", value: {operator: "LESS_THAN_OR_EQUAL", value: "' + + expected_end + + '"}}]' + ) assert build_search_filter(filters) == expected def test_date_range_filters(): filters = [ - WorkforceRequestedDateRangeFilter(value=DateRangeValue( - operator=RangeOperatorWithValue.Between, - value=DateRange(min=datetime.strptime("2024-01-01T00:00:00-0800", - "%Y-%m-%dT%H:%M:%S%z"), - max=datetime.strptime("2025-01-01T00:00:00-0800", - "%Y-%m-%dT%H:%M:%S%z")))), - WorkforceStageUpdatedRangeFilter(value=DateRangeValue( - operator=RangeOperatorWithValue.Between, - value=DateRange(min=datetime.strptime("2024-01-01T00:00:00-0800", - "%Y-%m-%dT%H:%M:%S%z"), - max=datetime.strptime("2025-01-01T00:00:00-0800", - "%Y-%m-%dT%H:%M:%S%z")))), + WorkforceRequestedDateRangeFilter( + value=DateRangeValue( + operator=RangeOperatorWithValue.Between, + value=DateRange( + min=datetime.strptime( + "2024-01-01T00:00:00-0800", "%Y-%m-%dT%H:%M:%S%z" + ), + max=datetime.strptime( + "2025-01-01T00:00:00-0800", "%Y-%m-%dT%H:%M:%S%z" + ), + ), + ) + ), + WorkforceStageUpdatedRangeFilter( + value=DateRangeValue( + operator=RangeOperatorWithValue.Between, + value=DateRange( + min=datetime.strptime( + "2024-01-01T00:00:00-0800", "%Y-%m-%dT%H:%M:%S%z" + ), + max=datetime.strptime( + "2025-01-01T00:00:00-0800", "%Y-%m-%dT%H:%M:%S%z" + ), + ), + ) + ), ] - assert build_search_filter( - filters - ) == '[{type: "workforce_requested_at", value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}}, {type: "workforce_stage_updated_at", value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}}]' + assert ( + build_search_filter(filters) + == '[{type: "workforce_requested_at", value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}}, {type: "workforce_stage_updated_at", value: {operator: "BETWEEN", value: {min: "2024-01-01T08:00:00Z", max: "2025-01-01T08:00:00Z"}}}]' + ) def test_task_count_filters(): filters = [ - TaskCompletedCountFilter(value=IntegerValue( - operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, value=1)), - TaskRemainingCountFilter(value=IntegerValue( - operator=RangeOperatorWithSingleValue.LessThanOrEqual, value=10)), + TaskCompletedCountFilter( + value=IntegerValue( + operator=RangeOperatorWithSingleValue.GreaterThanOrEqual, + value=1, + ) + ), + TaskRemainingCountFilter( + value=IntegerValue( + operator=RangeOperatorWithSingleValue.LessThanOrEqual, value=10 + ) + ), ] expected = '[{type: "task_completed_count", value: {operator: "GREATER_THAN_OR_EQUAL", value: 1}}, {type: "task_remaining_count", value: {operator: "LESS_THAN_OR_EQUAL", value: 10}}]' diff --git a/libs/labelbox/tests/unit/test_unit_webhook.py b/libs/labelbox/tests/unit/test_unit_webhook.py index 405955ce6..ae1b6884d 100644 --- a/libs/labelbox/tests/unit/test_unit_webhook.py +++ b/libs/labelbox/tests/unit/test_unit_webhook.py @@ -13,8 +13,7 @@ def test_webhook_create_with_no_secret(rand_gen): with pytest.raises(ValueError) as exc_info: 
Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "Secret must be a non-empty string." + assert str(exc_info.value) == "Secret must be a non-empty string." def test_webhook_create_with_no_topics(rand_gen): @@ -26,8 +25,7 @@ def test_webhook_create_with_no_topics(rand_gen): with pytest.raises(ValueError) as exc_info: Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "Topics must be a non-empty list." + assert str(exc_info.value) == "Topics must be a non-empty list." def test_webhook_create_with_no_url(rand_gen): @@ -39,5 +37,4 @@ def test_webhook_create_with_no_url(rand_gen): with pytest.raises(ValueError) as exc_info: Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "URL must be a non-empty string." + assert str(exc_info.value) == "URL must be a non-empty string." diff --git a/libs/labelbox/tests/unit/test_utils.py b/libs/labelbox/tests/unit/test_utils.py index dfd72c335..969f3a46b 100644 --- a/libs/labelbox/tests/unit/test_utils.py +++ b/libs/labelbox/tests/unit/test_utils.py @@ -1,21 +1,34 @@ import pytest -from labelbox.utils import format_iso_datetime, format_iso_from_string, sentence_case +from labelbox.utils import ( + format_iso_datetime, + format_iso_from_string, + sentence_case, +) -@pytest.mark.parametrize('datetime_str, expected_datetime_str', - [('2011-11-04T00:05:23Z', '2011-11-04T00:05:23Z'), - ('2011-11-04T00:05:23+00:00', '2011-11-04T00:05:23Z'), - ('2011-11-04T00:05:23+05:00', '2011-11-03T19:05:23Z'), - ('2011-11-04T00:05:23', '2011-11-04T00:05:23Z')]) +@pytest.mark.parametrize( + "datetime_str, expected_datetime_str", + [ + ("2011-11-04T00:05:23Z", "2011-11-04T00:05:23Z"), + ("2011-11-04T00:05:23+00:00", "2011-11-04T00:05:23Z"), + ("2011-11-04T00:05:23+05:00", "2011-11-03T19:05:23Z"), + ("2011-11-04T00:05:23", "2011-11-04T00:05:23Z"), + ], +) def test_datetime_parsing(datetime_str, expected_datetime_str): # NOTE I would normally not take 'expected' using another function from sdk code, but in this case this is exactly the usage in _validate_parse_datetime - assert format_iso_datetime( - format_iso_from_string(datetime_str)) == expected_datetime_str + assert ( + format_iso_datetime(format_iso_from_string(datetime_str)) + == expected_datetime_str + ) @pytest.mark.parametrize( - 'str, expected_str', - [('AUDIO', 'Audio'), - ('LLM_PROMPT_RESPONSE_CREATION', 'Llm prompt response creation')]) + "str, expected_str", + [ + ("AUDIO", "Audio"), + ("LLM_PROMPT_RESPONSE_CREATION", "Llm prompt response creation"), + ], +) def test_sentence_case(str, expected_str): assert sentence_case(str) == expected_str diff --git a/libs/labelbox/tests/utils.py b/libs/labelbox/tests/utils.py index 6fa2a8d8d..595fa0c76 100644 --- a/libs/labelbox/tests/utils.py +++ b/libs/labelbox/tests/utils.py @@ -14,9 +14,9 @@ def remove_keys_recursive(d, keys): # NOTE this uses quite a primitive check for cuids but I do not think it is worth coming up with a better one # Also this function is NOT written with performance in mind, good for small to mid size dicts like we have in our test def rename_cuid_key_recursive(d): - new_key = '' + new_key = "" for k in list(d.keys()): - if len(k) == 25 and not k.isalpha(): #primitive check for cuid + if len(k) == 25 and not k.isalpha(): # primitive check for cuid d[new_key] = d.pop(k) for k, v in d.items(): if isinstance(v, dict): @@ -27,4 +27,4 @@ def rename_cuid_key_recursive(d): rename_cuid_key_recursive(i) -INTEGRATION_SNAPSHOT_DIRECTORY = 
'tests/integration/snapshots' +INTEGRATION_SNAPSHOT_DIRECTORY = "tests/integration/snapshots" From 8d0da3ed215baa848e76bda90cf7b78c2b447317 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 11 Sep 2024 15:07:57 -0500 Subject: [PATCH 6/8] fixed error --- .../src/labelbox/schema/bulk_import_request.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/bulk_import_request.py b/libs/labelbox/src/labelbox/schema/bulk_import_request.py index 7caa2c6eb..44ac7cd6a 100644 --- a/libs/labelbox/src/labelbox/schema/bulk_import_request.py +++ b/libs/labelbox/src/labelbox/schema/bulk_import_request.py @@ -787,10 +787,8 @@ def validate_feature_schemas( # A union with custom construction logic to improve error messages class NDClassification( SpecialUnion, - Type[ - Union[ # type: ignore - NDText, NDRadio, NDChecklist - ] + Type[ # type: ignore + Union[NDText, NDRadio, NDChecklist] ], ): ... @@ -966,8 +964,8 @@ class NDMask(NDBaseTool): # A union with custom construction logic to improve error messages class NDTool( SpecialUnion, - Type[ - Union[ # type: ignore + Type[ # type: ignore + Union[ NDMask, NDTextEntity, NDPoint, @@ -981,10 +979,8 @@ class NDTool( class NDAnnotation( SpecialUnion, - Type[ - Union[ # type: ignore - NDTool, NDClassification - ] + Type[ # type: ignore + Union[NDTool, NDClassification] ], ): @classmethod From b511a12f3de9ab40891ae0b608baa7e24fa90197 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 11 Sep 2024 18:35:57 -0500 Subject: [PATCH 7/8] feedback --- .github/workflows/python-package-develop.yml | 2 +- .github/workflows/python-package-shared.yml | 7 ++----- libs/labelbox/pyproject.toml | 3 ++- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/.github/workflows/python-package-develop.yml b/.github/workflows/python-package-develop.yml index fc2c24e54..05eff5dc4 100644 --- a/.github/workflows/python-package-develop.yml +++ b/.github/workflows/python-package-develop.yml @@ -177,4 +177,4 @@ jobs: linux/arm64 tags: | - ${{ env.CONTAINER_IMAGE }}:${{ github.sha }} + ${{ env.CONTAINER_IMAGE }}:${{ github.sha }} \ No newline at end of file diff --git a/.github/workflows/python-package-shared.yml b/.github/workflows/python-package-shared.yml index acd30b299..4311020d8 100644 --- a/.github/workflows/python-package-shared.yml +++ b/.github/workflows/python-package-shared.yml @@ -18,7 +18,7 @@ on: test-env: required: true type: string - fixture-profile: + fixture-profile: required: true type: boolean @@ -36,9 +36,6 @@ jobs: - name: Linting working-directory: libs/labelbox run: rye run lint - - name: Format - working-directory: libs/labelbox - run: rye fmt --check integration: runs-on: ubuntu-latest concurrency: @@ -81,4 +78,4 @@ jobs: run: | rye sync -f --features labelbox/data rye run unit -n 32 - rye run data -n 32 + rye run data -n 32 \ No newline at end of file diff --git a/libs/labelbox/pyproject.toml b/libs/labelbox/pyproject.toml index 771117a01..b8188e916 100644 --- a/libs/labelbox/pyproject.toml +++ b/libs/labelbox/pyproject.toml @@ -89,8 +89,9 @@ unit = "pytest tests/unit" # LABELBOX_TEST_BASE_URL="http://host.docker.internal:8080" \ integration = { cmd = "pytest tests/integration" } data = { cmd = "pytest tests/data" } +ruff-fmt-check = "rye fmt --check" mypy-lint = "mypy src --pretty --show-error-codes --non-interactive --install-types" -lint = { chain = ["mypy-lint"] } +lint = { chain = ["mypy-lint", 
"ruff-fmt-check"] } test = { chain = ["lint", "unit", "integration"] } [tool.hatch.metadata] From ecdcde3d0e67ec2fc6a956b3b77ff5de50880217 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 11 Sep 2024 18:44:04 -0500 Subject: [PATCH 8/8] feedback and merge --- .github/workflows/python-package-shared.yml | 7 ++----- libs/labelbox/pyproject.toml | 3 ++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/.github/workflows/python-package-shared.yml b/.github/workflows/python-package-shared.yml index acd30b299..4311020d8 100644 --- a/.github/workflows/python-package-shared.yml +++ b/.github/workflows/python-package-shared.yml @@ -18,7 +18,7 @@ on: test-env: required: true type: string - fixture-profile: + fixture-profile: required: true type: boolean @@ -36,9 +36,6 @@ jobs: - name: Linting working-directory: libs/labelbox run: rye run lint - - name: Format - working-directory: libs/labelbox - run: rye fmt --check integration: runs-on: ubuntu-latest concurrency: @@ -81,4 +78,4 @@ jobs: run: | rye sync -f --features labelbox/data rye run unit -n 32 - rye run data -n 32 + rye run data -n 32 \ No newline at end of file diff --git a/libs/labelbox/pyproject.toml b/libs/labelbox/pyproject.toml index 771117a01..58ce3410a 100644 --- a/libs/labelbox/pyproject.toml +++ b/libs/labelbox/pyproject.toml @@ -89,8 +89,9 @@ unit = "pytest tests/unit" # LABELBOX_TEST_BASE_URL="http://host.docker.internal:8080" \ integration = { cmd = "pytest tests/integration" } data = { cmd = "pytest tests/data" } +rye-fmt-check = "rye fmt --check" mypy-lint = "mypy src --pretty --show-error-codes --non-interactive --install-types" -lint = { chain = ["mypy-lint"] } +lint = { chain = ["mypy-lint", "rye-fmt-check"] } test = { chain = ["lint", "unit", "integration"] } [tool.hatch.metadata]