diff --git a/README.md b/README.md
index e1bc32c88..2ff16fbe3 100644
--- a/README.md
+++ b/README.md
@@ -72,10 +72,6 @@ You should be set! Running the snippet above should create a dataset called `Tes
 ## Contribution Guidelines
 We encourage anyone to contribute to this repository to help improve it. Please refer to [Contributing Guide](CONTRIBUTING.md) for detailed information on how to contribute. This guide also includes instructions for how to build and run the SDK locally.
 
-## Develop with AI assistance
-### Use the codebase as context for large language models
-Using the [GPT repository loader](https://github.com/mpoon/gpt-repository-loader), we have created `lbx_prompt.txt` that contains data from all `.py` and `.md` files. The file has about 730k tokens. We recommend using Gemini 1.5 Pro with 1 million context length window.
-
 ## Documentation
 The SDK is well-documented to help developers get started quickly and use the SDK effectively. Here are links to that documentation:
 
diff --git a/lbx_prompt.txt b/lbx_prompt.txt
deleted file mode 100644
index 19873114d..000000000
--- a/lbx_prompt.txt
+++ /dev/null
@@ -1,43549 +0,0 @@
-The following text is a Git repository with code. The structure of the text is a series of sections that begin with ----, followed by a single line containing the file path and file name, followed by a variable number of lines containing the file contents. The text representing the Git repository ends when the symbols --END-- are encountered. Any further text beyond --END-- is meant to be interpreted as instructions using the aforementioned Git repository as context.
-----
-CHANGELOG.md
-# Changelog
-# Version 3.65.0 (2024-03-05)
-## Notes
-* Rerelease of 3.64.0
-
-# Version 3.64.0 (2024-02-29)
-
-## Added
-* `Client.get_catalog` Add catalog schema class. Catalog exports can now be made without creating a slice first
-* `last_activity_at` filter added to export_v2, allowing users to specify a datetime window without a slice
-
-## Removed
-* Review related WebhookDataSource topics
-
-## Notebooks
-* Added get_catalog notebook
-* Update custom metrics notebook
-* Update notebooks for video and image annotation import
-
-# Version 3.63.0 (2024-02-19)
-## Added
-* Ability for users to install and use the SDK with pydantic v2.* while still maintaining support for pydantic v1.*
-* `ModelRun` `export()` and `export_v2()` add model_run_details to support splits
-
-## Notebooks
-* Add composite mask notebook
-
-# Version 3.62.0 (2024-02-12)
-## Added
-* Support custom metrics for predictions (all applicable annotation classes)
-* `FoundryClient.run_app` Add data_row identifier validation for running foundry app
-* `Client.get_error_status_code` Default to 500 error if a server error is unparseable instead of throwing an exception
-
-## Updated
-* `DataRowMetadata, DataRowMetadataBatchResponse, _UpsertBatchDataRowMetadata` Make data_row_id and global_key optional in all schema types
-
-## Fixed
-* `ExportTask.__str__` Fix returned type in ExportTask instance representation
-
-## Removed
-* `Project.upsert_review_queue`
-
-## Notebooks
-* Update notebooks to new export methods
-* Add model slice notebook
-* Added support for annotation import with img bytes
-* Update user prompts for huggingface colab
-
-# Version 3.61.2 (2024-01-29)
-## Added
-* `ModelSlice.get_data_row_identifiers` for Foundry data rows
-
-## Fixed
-* `ModelSlice.get_data_row_identifiers` scoping by model run id
-
-# Version 3.61.1 (2024-01-25)
-## Fixed
-* Removed export API limit (5000)
-
-# Version 3.61.0 (2024-01-22)
-## Added
-* `ModelSlice.get_data_row_identifiers`
- * Fetches all data row ids and global keys for the model slice
- * NOTE: Foundry model slices are not supported yet
-## Updated
-* Updated exports v1 deprecation date to April 30th, 2024
-* Remove `streamable` param from export_v2 methods
-
-# Version 3.60.0 (2024-01-17)
-## Added
-* Get resource tags from a project
-* Method to CatalogSlice to get data row identifiers (both uids and global keys)
-* Added deprecation notice for the `upsert_review_queue` method in project
-## Notebooks
-* Update notebook for Project move_data_rows_to_task_queue
-* Added notebook for model foundry
-* Added notebook for migrating from Exports V1 to V2
-
-# Version 3.59.0 (2024-01-05)
-## Added
-* Support set_labeling_parameter_overrides for global keys
-* Support bulk_delete of data row metadata for global keys
-* Support bulk_export of data row metadata for global keys
-## Fixed
-* Stop overwriting class annotations on prediction upload
-* Prevent users from uploading video annotations over the API limit (5000)
-* Make description optional for foundry app
-## Notebooks
-* Update notebooks for Project set_labeling_parameter_overrides to add support for global keys
-
-# Version 3.58.1 (2023-12-15)
-## Added
-* Support to export all projects and all model runs to `export_v2` for a `dataset` and a `slice`
-## Notebooks
-* Update exports v2 notebook to include methods that return `ExportTask`
-
-# Version 3.58.0 (2023-12-11)
-## Added
-* `ontology_id` to the model app instantiation
-* LLM data generation label types
-* `run_foundry_app` to support running model foundry apps
-* Two methods for sending data rows to any workflow task in a project that can also include predictions from a model run, or annotations from a different project
-## Fixed
-* Documentation index for identifiables
-## Removed
-* Project.datasets and Datasets.projects methods as they have been deprecated
-## Notebooks
-* Added notebooks for Human labeling (GT/MAL/MEA) + data generation (GT/MAL)
-* Remove relationship annotations from text and conversational imports
-
-# Version 3.57.0 (2023-11-30)
-## Added
-* Global key support for Project move_data_rows_to_task_queue
-* Project name required for project creation
-## Notebooks
-* Updates to Image and Video notebook format
-* Added additional byte array examples for Image/Video import and Image prediction import notebook
-* Added a new LLM folder for new LLM import (MAL/MEA/Ground truth)
-
-# Version 3.56.0 (2023-11-21)
-## Added
-* Support for importing raster video masks from image bytes as a source
-* Add new ExportTask class to handle streaming of exports
-## Fixed
-* Check for empty fields during webhook creation
-## Notebooks
-* Updates to use bytes array for masks (video, image), and add examples of multiple annotations per frame (video)
-
-# Version 3.55.0 (2023-11-06)
-## Fixed
-* Fix the instantiation of `failed_data_row_ids` in Batch. This fix will address the issue with the `create_batch` method for more than 1,000 data rows.
-* Improve Python type hints for the `data_rows()` method in the Dataset.
-* Fix the `DataRowMetadataOntology` method `bulk_export()` to properly export global key(s).
-* In the `DataRowMetadataOntology` method `update_enum_option`, provide a more descriptive error message when the enum option is not valid.
-
-# Version 3.54.1 (2023-10-17)
-## Notebooks
-* Revised the notebooks to update outdated examples when using `client.create_project()` to create a project
-
-# Version 3.54.0 (2023-10-10)
-## Added
-* Add exports v1 deprecation warning
-* Create method in SDK to modify LPO priorities in bulk
-## Removed
-* Remove backoff library
-
-# Version 3.53.0 (2023-10-03)
-## Added
-* Remove LPO deprecation warning and allow greater range of priority values
-* Add an SDK method to get a data row by global key
-* Disallow invalid quality modes during create_project
-* Python 3.10 support
-* Change return of dataset.create_data_rows() to Task
-* Add new header to capture python version
-## Notebooks
-* Updated examples to match latest updates to SDK
-
-# Version 3.52.0 (2023-08-24)
-## Added
-* Added methods to create multiple batches for a project from a list of data rows
-* Limit the number of data rows to be checked for processing status
-
-# Version 3.51.0 (2023-08-14)
-## Added
-* Added global keys to export v2 filters for project, dataset and DataRow
-* Added workflow task status filtering for export v2
-
-## Notebooks
-* Removed labels notebook, since almost all of the relevant methods in the notebook were not compatible with the workflow paradigm.
-* Updated project.ipynb to use batches, not datasets
-
-# Version 3.50.0 (2023-08-04)
-## Added
-* Support batch_ids filter for projects in Exports v2
-* Added access_from field to project members to differentiate project-based roles from organization-level roles
-* Ability to use data_row_ids instead of the whole data row object for `DataRow.export_v2()`
-* Cursor-based pagination for dataset.data_rows()
-
-## Fixed
-* client.get_projects() unable to fetch details for LLM projects
-
-## Notebooks
-* Improved the documentation for `examples/basics/custom_embeddings.ipynb`
-* Updated the documentation for `examples/basics/data_row_metadata.ipynb`
-* Added details about CRUD methods to `examples/basics/ontologies.ipynb`
-
-# Version 3.49.1 (2023-06-29)
-## Fixed
-* Removed numpy version lock that caused Python versions >3.8 to download an incompatible numpy version
-
-# Version 3.49.0 (2023-06-27)
-
-## Changed
-* Improved batch creation logic when more than 1000 global keys are provided
-
-## Notebooks
-* Added example on how to access mark in export v2
-* Removed NDJSON library from `examples/basics/custom_embeddings.ipynb`
-* Removed `queue_mode` property from `create_project()` method call.
-
-# Version 3.48.0 (2023-06-13)
-## Added
-* Support for ISO format to exports V2 date filters
-* Support to specify confidence for all free-text annotations
-
-## Changed
-* Removed backports library and replaced it with the python-dateutil package to parse ISO strings
-
-## Notebooks
-* Added predictions to model run example
-* Added notebook to run yolov8 and sam on video and upload to LB
-* Updated Google Colab notebooks to reflect raster segmentation tool being released on 6/13
-* Updated radio NDJSON annotations format to support confidence
-* Added confidence to all free-text annotations (ndjson)
-* Fixed issues with the cv2 library stemming from the Geospatial notebook using a PNG map with a signed URL that had an expired token
-
-# Version 3.47.1 (2023-05-24)
-## Fixed
-* Loading of the ndjson parser when optional [data] libraries (geojson etc.)
are not installed - -# Version 3.47.0 (2023-05-23) -## Added -* Support for interpolated frames to export v2 - -## Changed -* Removed ndjson library and replaced it with a custom ndjson parser - -## Notebooks -* Removed confidence scores in annotations - video notebook -* Removed raster seg masks from video prediction -* Added export v2 example -* Added SAM and Labelbox connector notebook - -# Version 3.46.0 (2023-05-03) -## Added -* Global key support to DataRow Metadata `bulk_upsert()` function - -## Notebooks -* Removed dataset based projects from project setup notebook -* Updated all links to annotation import and prediction notebooks in examples README - -# Version 3.45.0 (2023-04-27) -## Changed -* Reduce threshold for async batch creation to 1000 data rows - -## Notebooks -* Added subclassifications to ontology notebook -* Added conversational and pdf predictions notebooks - -# Version 3.44.0 (2023-04-26) - -## Added -* `predictions` param for optionally exporting predictions in model run export v2 -* Limits on `model_run_ids` and `project_ids` on catalog export v2 params -* `WORKFLOW_ACTION` webhook topic -* Added `data_row_ids` filter for dataset and project export v2 - -## Fixed -* ISO timestamp parsing for datetime metadata -* Docstring typo for `client.delete_feature_schema_from_ontology()` - -## Notebooks -* Removed mention of embeddings metadata fields -* Fixed broken colab link on `examples/extras/classification-confusion-matrix.ipynb` -* Added free text classification example to video annotation import notebook -* Updated prediction_upload notebooks with Annotation Type examples - -# Version 3.43.0 (2023-04-05) - -## Added -* Nested object classifications to `VideoObjectAnnotation` -* Relationship Annotation Types -* Added `project_ids` and `model_run_ids` to params in all export_v2 functions - -## Fixed -* VideoMaskAnnotation annotation import - -## Notebooks -* Added DICOM annotation import notebook -* Added audio annotation import notebook -* Added HTML annotation import notebook -* Added relationship examples to annotation import notebooks -* Added global video classification example -* Added nested classification examples -* Added video mask example -* Added global key and LPOs to queue management notebook - -# Version 3.42.0 (2023-03-22) - -## Added -* Message based classifications with annotation types for conversations -* Video and raster segmentation annotation types -* Global key support to `ConversationEntity`, `DocumentEntity` and `DicomSegments` -* DICOM polyline annotation type -* Confidence attribute to classification annotations - -## Changed -* Increased metadata string size limit to 4096 chars -* Removed `deletedDataRowGlobalKey` from `get_data_row_ids_for_global_keys()` - -## Fixed -* Annotation data type coercion by Pydantic -* Error message when end point coordinates are smaller than start point coordinates -* Some typos in error messages - -## Notebooks -* Refactored video notebook to include annotation types -* Replaced data row ids with global keys in notebooks -* Replaced `create_data_row` with `create_data_rows` in notebooks - -# Version 3.41.0 (2023-03-15) - -## Added -* New data classes for creating labels: `AudioData`, `ConversationData`, `DicomData`, `DocumentData`, `HTMLData` -* New `DocumentEntity` annotation type class -* New parameter `last_activity_end` to `Project.export_labels()` - -## Notebooks -* Updated `annotation_import/pdf.ipynb` with example use of `DocumentEntity` class - -# Version 3.40.1 (2023-03-10) - -## Fixed -* Fixed issue 
where calling create_batch() on exported data rows wasn't working - -# Version 3.40.0 (2023-03-10) - -## Added -* Support Global keys to reference data rows in `Project.create_batch()`, `ModelRun.assign_data_rows_to_split()` -* Support upserting labels via project_id in `model_run.upsert_labels()` -* `media_type_override` param to export_v2 -* `last_activity_at` and `label_created_at` params to export_v2 -* New client method `is_feature_schema_archived()` -* New client method `unarchive_feature_schema_node()` -* New client method `delete_feature_schema_from_ontology()` - -## Changed -* Removed default task names for export_v2 - -## Fixed -* process_label() for COCO panoptic dataset - -## Notebooks -* Updated `annotation_import/pdf.ipynb` with more examples -* Added `integrations/huggingface/huggingface.ipynb` -* Fixed broken links for detectron notebooks in README -* Added Dataset QueueMode during project creation in `integrations/detectron2/coco_object.ipynb` -* Removed metadata and updated ontology in `annotation_import/text.ipynb` -* Removed confidence scores in `annotation_import/image.ipynb` -* Updated custom embedding tutorial links in `basics/data_row_metadata.ipynb` - -# Version 3.39.0 (2023-02-28) -## Added -* New method `Project.task_queues()` to obtain the task queues for a project. -* New method `Project.move_data_rows_to_task_queue()` for moving data rows to a specified task queue. -* Added more descriptive error messages for metadata operations -* Added `Task.errors_url` for async tasks that return errors as separate file (e.g. `export_v2`) -* Upsert data rows to model runs using global keys - -## Changed -* Updated `ProjectExportParams.labels` to `ProjectExportParams.label_details` -* Removed `media_attributes` from `DataRowParams` -* Added deprecation warnings for `LabelList` and removed its usage -* Removed unused arguments in `Project.export_v2` and `ModelRun.export_v2` -* In `Project.label_generator()`, we now filter skipped labels for project with videos - -## Notebooks -* Fixed `examples/label_export/images.ipynb` notebook metadata -* Removed unused `lb_serializer` imports -* Removed uuid generation in NDJson annotation payloads, as it is now optional -* Removed custom embeddings usage in `examples/basics/data_row_metadata.ipynb` -* New notebook `examples/basics/custom_embeddings.ipynb` for custom embeddings -* Updated `examples/annotation_import/text.ipynb` to use `TextData` and specify Text media type - -# Version 3.38.0 (2023-02-15) - -## Added -* All imports are available via `import labelbox as lb` and `import labelbox.types as lb_types`. -* Attachment_name support to create_attachment() - -## Changed -* `LabelImport.create_from_objects()`, `MALPredictionImport.create_from_objects()`, `MEAPredictionImport.create_from_objects()`, `Project.upload_annotations()`, `ModelRun.add_predictions()` now support Python Types for annotations. - -## Notebooks -* Removed NDJsonConverter from example notebooks -* Simplified imports in all notebooks -* Fixed nested classification in examples/annotation_import/image.ipynb -* Ontology (instructions --> name) - -# Version 3.37.0 (2023-02-08) -## Added -* New `last_activity_start` param to `project.export_labels()` for filtering which labels are exported. See docstring for more on how this works. 
- -## Changed -* Rename `Classification.instructions` to `Classification.name` - -## Fixed -* Retry connection timeouts - -# Version 3.36.1 (2023-01-24) -### Fixed -* `confidence` is now optional for TextEntity - -# Version 3.36.0 (2023-01-23) -### Fixed -* `confidence` attribute is now supported for TextEntity and Line predictions - -# Version 3.35.0 (2023-01-18) -### Fixed -* Retry 520 errors when uploading files - -# Version 3.34.0 (2022-12-22) -### Added -* Added `get_by_name()` method to MetadataOntology object to access both custom and reserved metadata by name. -* Added support for adding metadata by name when creating datarows using `DataRowMetadataOntology.bulk_upsert()`. -* Added support for adding metadata by name when creating datarows using `Dataset.create_data_rows()`, `Dataset.create_data_rows_sync()`, and `Dataset.create_data_row()`. -* Example notebooks for auto metrics in models - -### Changed -* `Dataset.create_data_rows()` max limit of DataRows increased to 150,000 -* Improved error handling for invalid annotation import content -* String metadata can now be 1024 characters long (from 500) - -## Fixed -* Broken urls in detectron notebook - -# Version 3.33.1 (2022-12-14) -### Fixed -* Fixed where batch creation limit was still limiting # of data rows. SDK should now support creating batches with up to 100k data rows - -# Version 3.33.0 (2022-12-13) -### Added -* Added SDK support for creating batches with up to 100k data rows -* Added optional media_type to `client.create_ontology_from_feature_schemas()` and `client.create_ontology()` - -### Changed -* String representation of `DbObject` subclasses are now formatted - -# Version 3.32.0 (2022-12-02) -### Added -* Added `HTML` Enum to `MediaType`. `HTML` is introduced as a new asset type in Labelbox. -* Added `PaginatedCollection.get_one()` and `PaginatedCollection.get_many()` to provide easy functions to fetch single and bulk instances of data for any function returning a `PaginatedCollection`. E.g. `data_rows = dataset.data_rows().get_many(10)` -* Added a validator under `ScalarMetric` to validate metric names against reserved metric names - -### Changed -* In `iou.miou_metric()` and `iou.feature_miou_metric`, iou metric renamed as `custom_iou` - -# Version 3.31.0 (2022-11-28) -### Added -* Added `client.clear_global_keys()` to remove global keys from their associated data rows -* Added a new attribute `confidence` to `AnnotationObject` and `ClassificationAnswer` for Model Error Analysis - -### Fixed -* Fixed `project.create_batch()` to work with both data_row_ids and data_row objects - -# Version 3.30.1 (2022-11-16) -### Added -* Added step to `project.create_batch()` to wait for data rows to finish processing -### Fixed -* Running `project.setup_editor()` multiple times no longer resets the ontology, and instead raises an error if the editor is already set up for the project - -# Version 3.30.0 (2022-11-11) -### Changed -* create_data_rows, create_data_rows_sync, create_data_row, and update data rows all accept the new data row input format for row data -* create_data_row now accepts an attachment parameter to be consistent with create_data_rows -* Conversational text data rows will be uploaded to a json file automatically on the backend to reduce the amount of i/o required in the SDK. 
- -# Version 3.29.0 (2022-11-02) -### Added -* Added new base `Slice` Entity/DbObject and `CatalogSlice` class -* Added `client.get_catalog_slice(id)` to fetch a CatalogSlice by ID -* Added `slice.get_data_row_ids()` to fetch data row ids of the slice -* Add deprecation warning for queue_mode == QueueMode.Dataset when creating a new project. -* Add deprecation warning for LPOs. - -### Changed -* Default behavior for metrics to not include subclasses in the calculation. - -### Fixed -* Polygon extraction from masks creating invalid polygons. This would cause issues in the coco converter. - -# Version 3.28.0 (2022-10-14) - -### Added -* Added warning for upcoming change in default project queue_mode setting -* Added notebook example for importing Conversational Text annotations using Model-Assisted Labeling - -### Changed -* Updated QueueMode enum to support new value for QueueMode.Batch = `BATCH`. -* Task.failed_data_rows is now a property - -### Fixed -* Fixed Task.wait_till_done() showing warning message for every completed task, instead of only warning when task has errors -* Fixed error on dataset creation step in examples/annotation_import/video.ipynb notebook - -# Version 3.27.2 (2022-10-04) - -### Added -* Added deprecation warning for missing `media_type` in `create_project` in `Client`. - -### Changed -* Updated docs for deprecated methods `_update_queue_mode` and `get_queue_mode` in `Project` - * Use the `queue_mode` attribute in `Project` to get and set the queue mode instead - * For more information, visit https://docs.labelbox.com/reference/migrating-to-workflows#upcoming-changes -* Updated `project.export_labels` to support filtering by start/end time formats "YYYY-MM-DD" and "YYYY-MM-DD hh:mm:ss" - -### Fixed - -# Version 3.27.1 (2022-09-16) -### Changed -* Removed `client.get_data_rows_for_global_keys` until further notice - -# Version 3.27.0 (2022-09-12) -### Added -* Global Keys for data rows - * Assign global keys to a data row with `client.assign_global_keys_to_data_rows` - * Get data rows using global keys with `client.get_data_row_ids_for_global_keys` and `client.get_data_rows_for_global_keys` -* Project Creation - * Introduces `Project.queue_mode` as an optional parameter when creating projects -* `MEAToMALPredictionImport` class - * This allows users to use predictions stored in Models for MAL -* `Task.wait_till_done` now shows a warning if task has failed -### Changed -* Increase scalar metric value limit to 100m -* Added deprecation warnings when updating project `queue_mode` -### Fixed -* Fix bug in `feature_confusion_matrix` and `confusion_matrix` causing FPs and FNs to be capped at 1 when there were no matching annotations - -# Version 3.26.2 (2022-09-06) -### Added -* Support for document (pdf) de/serialization from exports - * Use the `LBV1Converter.serialize()` and `LBV1Converter.deserialize()` methods -* Support for document (pdf) de/serialization for imports - * Use the `NDJsonConverter.serialize()` and `NDJsonConverter.deserialize()` methods - -# Version 3.26.1 (2022-08-23) -### Changed -* `ModelRun.get_config()` - * Modifies get_config to return un-nested Model Run config -### Added -* `ModelRun.update_config()` - * Updates model run training metadata -* `ModelRun.reset_config()` - * Resets model run training metadata -* `ModelRun.get_config()` - * Fetches model run training metadata - -### Changed -* `Model.create_model_run()` - * Add training metadata config as a model run creation param - -# Version 3.26.0 (2022-08-15) -## Added -* `Batch.delete()` 
which will delete an existing `Batch` -* `Batch.delete_labels()` which will delete all `Label`’s created after a `Project`’s mode has been set to batch. - * Note: Does not include labels that were imported via model-assisted labeling or label imports -* Support for creating model config when creating a model run -* `RAW_TEXT` and `TEXT_FILE` attachment types to replace the `TEXT` type. - -# Version 3.25.3 (2022-08-10) -## Fixed -* Label export will continue polling if the downloadUrl is None - -# Version 3.25.2 (2022-07-26) -## Updated -* Mask downloads now have retries -* Failed `upload_data` now shows more details in the error message - -## Fixed -* Fixed Metadata not importing with DataRows when bulk importing local files. -* Fixed COCOConverter failing for empty annotations - -## Documentation -* Notebooks are up-to-date with examples of importing annotations without `schema_id` - -# Version 3.25.1 (2022-07-20) -## Fixed -* Removed extra dependency causing import errors. - -# Version 3.25.0 (2022-07-20) - -## Added -* Importing annotations with model assisted labeling or label imports using ontology object names instead of schemaId now possible - * In Python dictionaries, you can now use `schemaId` key or `name` key for all tools, classifications, options -* Labelbox's Annotation Types now support model assisted labeling or label imports using ontology object names -* Export metadata when using the following methods: - * `Batch.export_data_rows(include_metadata=True)` - * `Dataset.export_data_rows(include_metadata=True)` - * `Project.export_queued_data_rows(include_metadata=True)` -* `VideoObjectAnnotation` has `segment_index` to group video annotations into video segments - -## Removed -* `Project.video_label_generator`. Use `Project.label_generator` instead. - -## Updated -* Model Runs now support unassigned splits -* `Dataset.create_data_rows` now has the following limits: - * 150,000 rows per upload without metadata - * 30,000 rows per upload with metadata - - -# Version 3.24.1 (2022-07-07) -## Updated -* Added `refresh_ontology()` as part of create/update/delete metadata schema functions - -# Version 3.24.0 (2022-07-06) -## Added -* `DataRowMetadataOntology` class now has functions to create/update/delete metadata schema - * `create_schema` - Create custom metadata schema - * `update_schema` - Update name of custom metadata schema - * `update_enum_options` - Update name of an Enum option for an Enum custom metadata schema - * `delete_schema` - Delete custom metadata schema -* `ModelRun` class now has `assign_data_rows_to_split` function, which can assign a `DataSplit` to a list of `DataRow`s -* `Dataset.create_data_rows()` can bulk import `conversationalData` - -# Version 3.23.3 (2022-06-23) - -## Fix -* Import for `numpy` has been adjusted to work with numpy v.1.23.0 - -# Version 3.23.2 (2022-06-15) -## Added -* `Data Row` object now has a new field, `metadata`, which returns metadata associated with data row as a list of `DataRowMetadataField` - * Note: When importing Data Rows with metadata, use the existing field, `metadata_fields` - -# Version 3.23.1 (2022-06-08) -## Added -* `Task` objects now have the following properties: - * `errors` - fetch information about why the task failed - * `result` - fetch the result of the task - * These are currently only compatible with data row import tasks. 
-* Officially added support for python 3.9 - -## Removed -* python 3.6 is no longer officially supported - -# Version 3.22.1 (2022-05-23) -## Updated -* Renamed `custom_metadata` to `metadata_fields` in DataRow - -# Version 3.22.0 (2022-05-20) -## Added -* `Dataset.create_data_row()` and `Dataset.create_data_rows()` now uploads metadata to data row -* Added `media_attributes` and `metadata` to `BaseData` - -## Updated -* Removed `iou` from classification metrics - -# Version 3.21.1 (2022-05-12) -## Updated - * `Project.create_batch()` timeout increased to 180 seconds - -# Version 3.21.0 (2022-05-11) -## Added - * Projects can be created with a `media_type` - * Added `media_type` attribute to `Project` - * New `MediaType` enumeration - -## Fix - * Added back the mimetype to datarow bulk uploads for orgs that require delegated access - -# Version 3.20.1 (2022-05-02) -## Updated -* Ontology Classification `scope` field is only set for top level classifications - -# Version 3.20.0 (2022-04-27) -## Added -* Batches in a project can be retrieved with `project.batches()` -* Added `Batch.remove_queued_data_rows()` to cancel remaining data rows in batch -* Added `Batch.export_data_rows()` which returns `DataRow`s for a batch - -## Updated -* NDJsonConverter now supports Video bounding box annotations. - * Note: Currently does not support nested classifications. - * Note: Converting an export into Labelbox annotation types, and back to export will result in only keyframe annotations. This is to support correct import format. - - -## Fix -* `batch.project()` now works - -# Version 3.19.1 (2022-04-14) -## Fix -* `create_data_rows` and `create_data_rows_sync` now uploads the file with a mimetype -* Orgs that only allow DA uploads were getting errors when using these functions - -# Version 3.19.0 (2022-04-12) -## Added -* Added Tool object type RASTER_SEGMENTATION for Video and DICOM ontologies - -# Version 3.18.0 (2022-04-07) -## Added -* Added beta support for exporting labels from model_runs -* LBV1Converter now supports data_split key -* Classification objects now include `scope` key - -## Fix -* Updated notebooks - -# Version 3.17.2 (2022-03-28) -## Fix -* Project.upsert_instructions now works properly for new projects. - -# Version 3.17.1 (2022-03-25) -## Updated -* Remove unused rasterio dependency - -# Version 3.17.0 (2022-03-22) -## Added -* Create batches from the SDK (Beta). Learn more about [batches](https://docs.labelbox.com/docs/batches) -* Support for precision and recall metrics on Entity annotations - -## Fix -* `client.create_project` type hint added for its return type - -## Updated -* Removed batch MVP code - -# Version 3.16.0 (2022-03-08) -## Added -* Ability to fetch a model run with `client.get_model_run()` -* Ability to fetch labels from a model run with `model_run.export_labels()` - - Note: this is only Experimental. 
To use, client param `enable_experimental` should be set to true
-* Ability to delete an attachment
-
-## Fix
-* Logger level is no longer set to INFO
-
-## Updated
-* Deprecation: Creating Dropdowns will no longer be supported after 2022-03-31
- - This includes creating/adding Dropdowns to an ontology
- - This includes creating/adding Dropdown Annotation Type
- - For the same functionality, use Radio
- - This will not affect existing Dropdowns
-
-# Version 3.15.0 (2022-02-28)
-## Added
-* Extras folder which contains useful applications using the SDK
-* Addition of ResourceTag at the Organization and Project level
-* Updates to the example notebooks
-
-## Fix
-* EPSGTransformer now properly transforms Polygon to Polygon
-* VideoData string representation now properly shows VideoData
-
-
-# Version 3.14.0 (2022-02-10)
-## Added
-* Updated metrics for classifications to be per-answer
-
-
-# Version 3.13.0 (2022-02-07)
-## Added
-* Added `from_shapely` method to create annotation types from Shapely objects
-* Added `start` and `end` filter on the following methods
-- `Project.export_labels()`
-- `Project.label_generator()`
-- `Project.video_label_generator()`
-* Improved type hinting
-
-
-# Version 3.12.0 (2022-01-19)
-## Added
-* Tiled Imagery annotation type
-- A set of classes that support Tiled Image assets
-- New demo notebook can be found here: examples/annotation_types/tiled_imagery_basics.ipynb
-- Updated tiled image mal can be found here: examples/model_assisted_labeling/tiled_imagery_mal.ipynb
-* Support transformations from one EPSG to another with `EPSGTransformer` class
-- Supports EPSG to Pixel space transformations
-
-
-# Version 3.11.1 (2022-01-10)
-## Fix
-* Make `TypedArray` class compatible with `numpy` versions `>= 1.22.0`
-* `project.upsert_review_queue` quotas can now be in the inclusive range [0,1]
-* Restore support for upserting html instructions to a project
-
-# Version 3.11.0 (2021-12-15)
-
-## Fix
-* `Dataset.create_data_rows()` now accepts an iterable of data row information instead of a list
-* `project.upsert_instructions()`
- * now only supports pdfs since that is what the editor requires
- * There was a bug that could cause this to modify the project ontology
-
-## Removed
-* `DataRowMetadataSchema.id` use `DataRowMetadataSchema.uid` instead
-* `ModelRun.delete_annotation_groups()` use `ModelRun.delete_model_run_data_rows()` instead
-* `ModelRun.annotation_groups()` use `ModelRun.model_run_data_rows()` instead
-
-# Version 3.10.0 (2021-11-18)
-## Added
-* `AnnotationImport.wait_until_done()` accepts a `show_progress` param. This is set to `False` by default (see the sketch below).
- * If enabled, a tqdm progress bar will indicate the import progress.
- * This works for all classes that inherit from AnnotationImport: `LabelImport`, `MALPredictionImport`, `MEAPredictionImport`
- * This is not supported for `BulkImportRequest` (which will eventually be replaced by `MALPredictionImport`)
-* `Option.label` and `Option.value` can now be set independently
-* `ClassificationAnswer`s now support a new `keyframe` field for videos
-* New `LBV1Label.media_type` field. This is a placeholder for future backend changes.
-
-## Fix
-* Nested checklists can have extra brackets. This would cause the annotation type converter to break.
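-
-A minimal sketch of the new `show_progress` flow, assuming the modern `import labelbox as lb` namespace; the API key, project id, import name, and `predictions` payload below are placeholders:
-
-```python
-import labelbox as lb
-
-client = lb.Client(api_key="<API_KEY>")  # placeholder key
-upload_job = lb.MALPredictionImport.create_from_objects(
-    client=client,
-    project_id="<PROJECT_ID>",           # placeholder project id
-    name="mal-import-with-progress",     # placeholder import name
-    predictions=predictions,             # placeholder: an iterable of annotation objects
-)
-upload_job.wait_until_done(show_progress=True)  # renders a tqdm progress bar while polling
-```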
-
-# Version 3.9.0 (2021-11-12)
-## Added
-* New ontology management features
- * Query for ontologies by name with `client.get_ontologies()` or by id using `client.get_ontology()`
- * Query for feature schemas by name with `client.get_feature_schemas()` or id using `client.get_feature_schema()`
- * Create feature schemas with `client.create_feature_schemas()`
- * Create ontologies from normalized ontology data with `client.create_ontology()`
- * Create ontologies from feature schemas with `client.create_ontology_from_feature_schemas()`
- * Set up a project from an existing ontology with `project.setup_editor()`
- * Added new `FeatureSchema` entity
-* Add support for new queue modes
- * Send batches of data directly to a project queue with `project.queue()`
- * Remove items from a project queue with `project.dequeue()`
- * Query for and toggle the queue mode
-
-# Version 3.8.0 (2021-10-22)
-## Added
-* `ModelRun.upsert_data_rows()`
- * Add data rows to a model run without also attaching labels
-* `OperationNotAllowedException`
- * Raised when users hit resource limits or are not allowed to use a particular operation
-
-## Updated
-* `ModelRun.upsert_labels()`
- * Blocks until the upsert job is complete. Error messages have been improved
-* `Organization.invite_user()` and `Organization.invite_limit()` are no longer experimental
-* `AnnotationGroup` was renamed to `ModelRunDataRow`
-* `ModelRun.delete_annotation_groups()` was renamed to `ModelRun.delete_model_run_data_rows()`
-* `ModelRun.annotation_groups()` was renamed to `ModelRun.model_run_data_rows()`
-
-## Fix
-* `DataRowMetadataField` no longer relies on pydantic for field validation and coercion
- * This prevents unintended type coercions from occurring
-
-# Version 3.7.0 (2021-10-11)
-## Added
-* Search for data row ids from external ids without specifying a dataset
- * `client.get_data_row_ids_for_external_ids()`
-* Support for numeric metadata type
-
-## Updated
-* The following `DataRowMetadataOntology` fields were renamed:
- * `all_fields` -> `fields`
- * `all_fields_id_index` -> `fields_by_id`
- * `reserved_id_index` -> `reserved_by_id`
- * `reserved_name_index` -> `reserved_by_name`
- * `custom_id_index` -> `custom_by_id`
- * `custom_name_index` -> `custom_by_name`
-
-
-# Version 3.6.1 (2021-10-07)
-* Fix import error that appears when exporting labels
-
-# Version 3.6.0 (2021-10-04)
-## Added
-* Bulk export metadata with `DataRowMetadataOntology.bulk_export()` (see the sketch below)
-* Add docstring examples of annotation types and a few helper methods
-
-## Updated
-* Update metadata notebook under examples/basics to include bulk_export.
-* Allow color to be a single integer when constructing Mask objects
-* Allow users to pass int arrays to RasterData and attempt coercion to uint8
-
-## Removed
-* data_row.metadata was removed in favor of bulk exports.
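-
-A minimal sketch of a bulk metadata export; the API key is a placeholder and the `dataset` handle is assumed to already exist:
-
-```python
-import labelbox as lb
-
-client = lb.Client(api_key="<API_KEY>")  # placeholder key
-mdo = client.get_data_row_metadata_ontology()
-# Collect data row ids from an existing dataset (assumed handle), then export
-# the metadata attached to those rows in one call.
-data_row_ids = [data_row.uid for data_row in dataset.data_rows()]
-metadata = mdo.bulk_export(data_row_ids)
-```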
-
-
-# Version 3.5.0 (2021-09-15)
-## Added
-* Diagnostics custom metrics
- * Metric annotation types
- * Update ndjson converter to be compatible with metric types
- * Metric library for computing confusion matrix metrics and iou
- * Demo notebooks under `examples/diagnostics`
-* COCO Converter
-* Detectron2 example integration notebooks
-
-# Version 3.4.1 (2021-09-10)
-## Fix
-* IAM validation exception message
-
-# Version 3.4.0 (2021-09-10)
-## Added
-* New `IAMIntegration` entity
-* `Client.create_dataset()` compatibility with delegated access
-* `Organization.get_iam_integrations()` to list all integrations available to an org
-* `Organization.get_default_iam_integration()` to only get the default IAM integration
-
-# Version 3.3.0 (2021-09-02)
-## Added
-* `Dataset.create_data_rows_sync()` for synchronous bulk uploads of data rows
-* `Model.delete()`, `ModelRun.delete()`, and `ModelRun.delete_annotation_groups()` to clean up models, model runs, and annotation groups.
-
-## Fix
-* Increased timeout for label exports since projects with many segmentation masks weren't finishing quickly enough.
-
-# Version 3.2.1 (2021-08-31)
-## Fix
-* Resolved an issue where `create_data_rows()` was not working on Amazon Linux
-
-# Version 3.2.0 (2021-08-26)
-## Added
-* List `BulkImportRequest`s for a project with `Project.bulk_import_requests()`
-* Improvements to `Dataset.create_data_rows()`
- * Add attachments when bulk importing data rows
- * Provide external ids when creating data rows from local files
- * Get more informative error messages when the api rejects an import
-
-## Fix
-* Bug causing `project.label_generator()` to fail when projects had benchmarks
-
-# Version 3.1.0 (2021-08-18)
-## Added
-* Support for new HTML attachment type
-* Delete Bulk Import Requests with `BulkImportRequest.delete()`
-
-## Misc
-* Updated MEAPredictionImport class to use the latest GraphQL endpoints
-
-
-# Version 3.0.1 (2021-08-12)
-## Fix
-* Issue with inferring text type from export
-
-# Version 3.0.0 (2021-08-12)
-## Added
-* Annotation types
- - A set of python objects for working with labelbox data
- - Creates a standard interface for both exports and imports
- - See example notebooks on how to use under examples/annotation_types
- - Note that these types are not yet supported for tiled imagery
-* MEA Support
- - Beta MEA users can now just use the latest SDK release
-* Metadata support
- - New metadata features are now fully supported by the SDK
-* Easier export
- - `project.export_labels()` accepts a boolean indicating whether or not to download the result
- - Create annotation objects directly from exports with `project.label_generator()` or `project.video_label_generator()`
- - `project.video_label_generator()` asynchronously fetches video annotations
-* Retry logic on data uploads
- - Bulk creation of data rows will be more reliable
-* Datasets
- - Determine the number of data rows just by calling `dataset.row_count`.
- - Updated threading logic in create_data_rows() to make it compatible with AWS Lambdas
-* Ontology
- - `OntologyBuilder`, `Classification`, `Option`, and `Tool` can now be imported from `labelbox` instead of `labelbox.schema.ontology`
-
-## Removed
-* Deprecated:
- - `project.reviews()`
- - `project.create_prediction()`
- - `project.create_prediction_model()`
- - `project.create_label()`
- - `Project.predictions()`
- - `Project.active_prediction_model`
- - `data_row.predictions`
- - `PredictionModel`
- - `Prediction`
-* Replaced:
- - `data_row.metadata()` use `data_row.attachments()` instead
- - `data_row.create_metadata()` use `data_row.create_attachments()` instead
- - `AssetMetadata` use `AssetAttachment` instead
-
-## Fixes
-* Support derived classes of ontology objects when using `from_dict`
-* Notebooks:
- - Video export bug where the code would fail if the exported projects had tools other than bounding boxes
- - MAL demos were broken due to an image download failing.
-
-## Misc
-* Data processing dependencies are not installed by default for users that only want client functionality.
-* To install all dependencies required for the data modules (annotation types and mea metric calculation) use `pip install labelbox[data]`
-* Decrease wait time between updates for `BulkImportRequest.wait_until_done()`.
-* Organization is no longer used to create the LFO in `Project.setup()`
-
-
-# Version 3.0.0-rc3 (2021-08-11)
-## Updates
-* Geometry.raster now has a consistent interface and improved functionality
-* Renamed schema_id to feature_schema_id in the `FeatureSchema` class
-* `Mask` objects now use `MaskData` to represent segmentation masks instead of `ImageData`
-
-# Version 3.0.0-rc2 (2021-08-09)
-## Updates
-* Rename `data` property of TextData, ImageData, and VideoData types to `value`.
-* Decrease wait time between updates for `BulkImportRequest.wait_until_done()`
-* Organization is no longer used to create the LFO in `Project.setup()`
-
-
-# Version 3.0.0-rc1 (2021-08-05)
-## Added
-* Model diagnostics notebooks
-* Minor annotation type improvements
-
-# Version 3.0.0-rc0 (2021-08-04)
-## Added
-* Annotation types
- - A set of python objects for working with labelbox data
- - Creates a standard interface for both exports and imports
- - See example notebooks on how to use under examples/annotation_types
- - Note that these types are not yet supported for tiled imagery
-* MEA Support
- - Beta MEA users can now just use the latest SDK release
-* Metadata support
- - New metadata features are now fully supported by the SDK
-* Easier export
- - `project.export_labels()` accepts a boolean indicating whether or not to download the result
- - Create annotation objects directly from exports with `project.label_generator()` or `project.video_label_generator()`
- - `project.video_label_generator()` asynchronously fetches video annotations
-* Retry logic on data uploads
- - Bulk creation of data rows will be more reliable
-* Datasets
- - Determine the number of data rows just by calling `dataset.row_count`.
- - Updated threading logic in create_data_rows() to make it compatible with AWS Lambdas
-* Ontology
- - `OntologyBuilder`, `Classification`, `Option`, and `Tool` can now be imported from `labelbox` instead of `labelbox.schema.ontology`
-
-## Removed
-* Deprecated:
- - `project.reviews()`
- - `project.create_prediction()`
- - `project.create_prediction_model()`
- - `project.create_label()`
- - `Project.predictions()`
- - `Project.active_prediction_model`
- - `data_row.predictions`
- - `PredictionModel`
- - `Prediction`
-* Replaced:
- - `data_row.metadata()` use `data_row.attachments()` instead
- - `data_row.create_metadata()` use `data_row.create_attachments()` instead
- - `AssetMetadata` use `AssetAttachment` instead
-
-## Fixes
-* Support derived classes of ontology objects when using `from_dict`
-* Notebooks:
- - Video export bug where the code would fail if the exported projects had tools other than bounding boxes
- - MAL demos were broken due to an image download failing.
-
-## Misc
-* Data processing dependencies are not installed by default for users that only want client functionality.
-* To install all dependencies required for the data modules (annotation types and mea metric calculation) use `pip install labelbox[data]`
-
-# Version 2.7b1+mea (2021-06-27)
-## Fix
-* No longer convert `ModelRun.created_by_id` to cuid on construction of a `ModelRun`.
- * This was causing queries for ModelRuns to fail.
-
-# Version 2.7b0+mea (2021-06-27)
-## Fix
-* Update `AnnotationGroup` to expect labelId to be a cuid instead of uuid.
-* Update `datarow_miou` to support masks with multiple classes in them.
-
-# Version 2.7.0 (2021-06-27)
-## Added
-* Added `dataset.export_data_rows()` which returns all `DataRows` for a `Dataset`.
-
-# Version 2.6b2+mea (2021-06-16)
-## Added
-* `ModelRun.annotation_groups()` to fetch data rows and label information for a model run
-
-# Version 2.6.0 (2021-06-11)
-## Fix
-* Updated `create_mask_ndjson` helper function in `image_mal.ipynb` to use the color arg
- instead of a hardcoded color.
-
-## Added
-* asset_metadata is now deprecated and has been replaced with asset_attachments
- * `AssetAttachment` replaces `AssetMetadata` (see definition for updated attribute names)
- * Use `DataRow.attachments()` instead of `DataRow.metadata()`
- * Use `DataRow.create_attachment()` instead of `DataRow.create_metadata()`
-* Update pydantic version
-
-# Version 2.5b0+mea (2021-06-11)
-## Added
-* Added new `Model` and `ModelRun` entities
-* Update client to support creating and querying for `Model`s
-* Implement new prediction import pipeline to support both MAL and MEA
-* Added notebook to demonstrate how to use MEA
-* Added `datarow_miou` for calculating datarow level iou scores
-
-
-# Version 2.5.6 (2021-05-19)
-## Fix
-* MAL validation no longer raises exception when NER tool has same start and end location
-
-# Version 2.5.5 (2021-05-17)
-## Added
-* `DataRow` now has a `media_attributes` field
-* `DataRow`s can now be looked up from `LabelingParameterOverride`s
-* `Project.export_queued_data_rows` to export all data rows in a queue for a project at once
-
-# Version 2.5.4 (2021-04-22)
-## Added
-* User management
- * Query for remaining invites and users available to an organization
- * Set and update organization roles
- * Set / update / revoke project role
- * Delete users from organization
- * Example notebook added under examples/basics
-* Issues and comments export
- * Bulk export issues and comments. See `Project.export_issues`
-* MAL on Tiled Imagery
- * Example notebook added under examples/model_assisted_labeling
- * `Dataset.create_data_rows` now allows users to upload TMS imagery
-
-# Version 2.5.3 (2021-04-01)
-## Added
-* Cleanup and add additional example notebooks
-* Improved README for SDK and examples
-* Easier to retrieve per annotation `BulkImportRequest` status, errors, and inputs
- * See `BulkImportRequest.errors`, `BulkImportRequest.statuses`, `BulkImportRequest.inputs` for more information
-
-# Version 2.5.2 (2021-03-25)
-## Fix
-* Ontology builder defaults to None for missing fields instead of empty lists
-* MAL validation added extra fields to subclasses
-
-### Added
-* Example notebooks
-
-## Version 2.5.1 (2021-03-15)
-### Fix
-* `Dataset.data_row_for_external_id` no longer throws `ResourceNotFoundError` when there are duplicates
-* Improved doc strings
-
-### Added
-* OntologyBuilder for making project setup easier
-* Now supports `IMAGE_OVERLAY` metadata
-* Webhooks for review topics added
-* Upload project instructions with `Project.upsert_instructions`
-* User input validation
- * MAL validity is now checked client side for faster feedback
- * type and value checks added in a few places
-
-## Version 2.4.11 (2021-03-07)
-### Fix
-* Increase query timeout
-* Retry 502s
-
-## Version 2.4.10 (2021-02-05)
-### Added
-* SDK version added to request headers
-
-## Version 2.4.9 (2020-11-09)
-### Fix
-* 2.4.8 was broken for > Python 3.6
-### Added
-* include new `Project.enable_model_assisted_labeling` method for turning on [model-assisted labeling](https://labelbox.com/docs/automation/model-assisted-labeling)
-
-## Version 2.4.8 (2020-11-06)
-### Fix
-* fix failing `next` call https://github.com/Labelbox/labelbox-python/issues/74
-
-## Version 2.4.7 (2020-09-10)
-### Added
-* `Ontology` schema for interacting with ontologies and their schema nodes
-
-## Version 2.4.6 (2020-09-03)
-### Fix
-* fix failing `create_metadata` calls
-
-## Version 2.4.5 (2020-09-01)
-### Added
-* retry capabilities for common flaky API failures
-* protection against improper types passed into `Project.upload_annotations`
-* pass through API error messages when possible
-
-## Version 2.4.3 (2020-08-04)
-
-### Added
-* `BulkImportRequest` data type
-* `Project.upload_annotation` supports uploading via a local ndjson file, url, or an iterable of annotations
-
-## Version 2.4.2 (2020-08-01)
-### Fixed
-* `Client.upload_data` will now pass the correct `content-length` when uploading data.
-
-## Version 2.4.1 (2020-07-22)
-### Fixed
-* `Dataset.create_data_row` and `Dataset.create_data_rows` will now upload with content type to ensure the Labelbox editor can show videos.
-
-## Version 2.4 (2020-01-30)
-
-### Added
-* `Prediction` and `PredictionModel` data types.
-
-## Version 2.3 (2019-11-12)
-
-### Changed
-* `Client.execute` now automatically extracts the 'data' value from the
-returned `dict`. This *breaks* existing code that directly uses the
-`Client.execute` method.
-* Major code reorganization, naming and test improvements.
-* `Label.seconds_to_label` field value is now optional when creating
-a `Label`. Default value is 0.0.
-
-### Added
-* `Project.upsert_review_queue` method.
-* `Project.extend_reservations` method.
-* `Label.created_by` relationship (To-One User).
-* Changelog.
-
-### Fixed
-* `Dataset.create_data_row` upload of local file data.
-
-## Version 2.2 (2019-10-18)
-Changelog not maintained before version 2.2.
-
-----
-README.md
-# Labelbox Python SDK
-[![Release Notes](https://img.shields.io/github/release/labelbox/labelbox-python)](https://github.com/Labelbox/labelbox-python/releases)
-[![CI](https://github.com/labelbox/labelbox-python/actions/workflows/python-package.yml/badge.svg)](https://github.com/labelbox/labelbox-python/actions)
-[![Downloads](https://pepy.tech/badge/labelbox)](https://pepy.tech/project/labelbox)
-[![Dependency Status](https://img.shields.io/librariesio/github/labelbox/labelbox-python)](https://libraries.io/github/labelbox/labelbox-python)
-[![Open Issues](https://img.shields.io/github/issues-raw/labelbox/labelbox-python)](https://github.com/labelbox/labelbox-python/issues)
-[![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
-[![Twitter Follow](https://img.shields.io/twitter/follow/labelbox.svg?style=social&label=Follow)](https://twitter.com/labelbox)
-[![LinkedIn Follow](https://img.shields.io/badge/Follow-LinkedIn-blue.svg?style=flat&logo=linkedin)](https://www.linkedin.com/company/labelbox/)
-
-
-Labelbox is a cloud-based data-centric AI platform designed to help teams create high-quality training data for their AI models. It provides a suite of tools and features that streamline the process of data curation, labeling, and model output evaluation for computer vision and large language models. Visit [Labelbox](http://labelbox.com/) for more information.
-
-
-The Python SDK provides a convenient way to interact with Labelbox programmatically, offering advantages over REST or GraphQL APIs:
-
-* **Simplified interactions:** The SDK abstracts away the complexities of API calls, making it easier to work with Labelbox.
-* **Object-oriented approach:** The SDK provides an object-oriented interface, allowing you to interact with Labelbox entities (projects, datasets, labels, etc.) as Python objects.
-* **Extensibility:** The SDK can be extended to support custom data formats and operations.
-
-## Table of Contents
-
-- [Labelbox Python SDK](#labelbox-python-sdk)
-  - [Table of Contents](#table-of-contents)
-  - [Installation](#installation)
-  - [Code Architecture](#code-architecture)
-  - [Extending the SDK](#extending-the-sdk)
-  - [Documentation](#documentation)
-  - [Contribution](#contribution)
-
-## Installation
-![Python Version](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue.svg)
-
-Welcome to the quick start guide for integrating Labelbox into your Python projects. Whether you're looking to incorporate advanced data labeling into your workflow or simply explore the capabilities of the Labelbox Python SDK, this guide will walk you through the two main methods of setting up Labelbox in your environment: via a package manager and by building it locally.
-
-### Easy Installation with Package Manager
-
-To get started with the least amount of hassle, follow these simple steps to install the Labelbox Python SDK using pip, Python's package manager.
-
-1. **Ensure pip is Installed:** First, make sure you have `pip` installed on your system. It's the tool we'll use to install the SDK.
-
-2. **Sign Up for Labelbox:** If you haven't already, create a free account at [Labelbox](http://app.labelbox.com/) to access its features.
-
-3. **Generate Your API Key:** Log into Labelbox and navigate to [Account > API Keys](https://docs.labelbox.com/docs/create-an-api-key) to generate an API key.
You'll need this for programmatic access to Labelbox. - -4. **Install the SDK:** Open your terminal or command prompt and run the following command to install the Labelbox Python SDK: - - ```bash - pip install labelbox - ``` - -5. **Install Optional Dependencies:** For enhanced functionality, such as data processing, you can install additional dependencies with: - - ```bash - pip install "labelbox[data]" - ``` - - This includes essential libraries like Shapely, GeoJSON, NumPy, Pillow, and OpenCV-Python, enabling you to handle a wide range of data types and perform complex operations. - -### Building and Installing Locally - -For those who prefer or require a more hands-on approach, such as contributing to the SDK or integrating it into a complex project, building the SDK locally is the way to go. - -#### Prerequisites - -- **pip Installation:** Ensure `pip` is installed on your system. For macOS users, you can easily set it up with: - - ```bash - curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py - python3 get-pip.py - ``` - - If you encounter a warning about `pip` not being in your PATH, you'll need to add it manually by modifying your shell configuration (`.zshrc`, `.bashrc`, etc.): - - ```bash - export PATH=/Users//Library/Python/3.8/bin:$PATH - ``` - -#### Steps for Local Installation - -1. **Clone the SDK Repository:** First, clone the Labelbox SDK repository from GitHub to your local machine. - -2. **Install the SDK Locally:** Navigate to the root directory of the cloned repository and run: - - ```bash - pip3 install -e . - ``` - -3. **Install Required Dependencies:** To ensure all dependencies are met, run: - - ```bash - pip3 install -r requirements.txt - ``` - - For additional data processing capabilities, remember to install the `data` extra as mentioned in the easy installation section. - - -## Code Architecture - -The Labelbox Python SDK is designed to be modular and extensible. Key files and classes include: - -- **`labelbox/client.py`:** Contains the `Client` class, which provides methods for interacting with the Labelbox API. -- **`labelbox/orm/model.py`:** Defines the data model for Labelbox entities like projects, datasets, and labels. -- **`labelbox/schema/*.py`:** Contains classes representing specific Labelbox entities and their attributes. -- **`labelbox/data/annotation_types/*.py`:** Defines classes for different annotation types, such as bounding boxes, polygons, and classifications. -- **`labelbox/data/serialization/*.py`:** Provides converters for different data formats, including NDJSON and Labelbox v1 JSON. - -The SDK wraps the GraphQL APIs and provides a Pythonic interface for interacting with Labelbox. - - -## Extending the SDK - -The Labelbox Python SDK is designed to be extensible. Here are examples of how you can extend the SDK: - -### Adding an Export Format Converter - -You can add a new export format converter by creating a class that inherits from the `Converter` class and implements the `convert` method. For example, to add a converter for a custom JSON format: - -```python -class CustomJsonConverter(Converter[CustomJsonOutput]): - - def convert(self, input_args: Converter.ConverterInputArgs) -> Iterator[CustomJsonOutput]: - # Implement logic to convert data to custom JSON format - yield CustomJsonOutput(...) 
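-
-# Hypothetical usage sketch (illustrative only): `input_args` wiring and the
-# `handle()` consumer below are assumptions, not part of the SDK.
-# for converted in CustomJsonConverter().convert(input_args):
-#     handle(converted)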
-``` - -## Documentation - -- [Visit our docs](https://docs.labelbox.com/reference) to learn how the SDK works -- Checkout our [notebook examples](examples/) to follow along with interactive tutorials -- view our [API reference](https://labelbox-python.readthedocs.io/en/latest/index.html). - - -## Contribution -Please consult `CONTRIB.md` ----- -CONTRIB.md -# Labelbox Python SDK Contribution Guide - -## Contribution Guidelines -Thank you for expressing your interest in contributing to the Labelbox SDK. -To ensure that your contribution aligns with our guidelines, please carefully -review the following considerations before proceeding: - -* For feature requests, we recommend consulting with Labelbox support or - creating a [Github Issue](https://github.com/Labelbox/labelbox-python/issues) on our repository. -* We can only accept general solutions that address common issues rather than solutions - designed for specific use cases. Acceptable contributions include simple bug fixes and - improvements to functions within the schema/ package. -* Please ensure that any new libraries are compliant with the Apache license that governs the Labelbox SDK. -* Ensure that you update any relevant docstrings and comments within your code -* Ensure that any new python components like classes, methods etc that need to feature in labelbox documentation have entries in the file [index.rst](https://github.com/Labelbox/labelbox-python/blob/develop/docs/source/index.rst). Also make sure you run `make html` locally in the `docs` folder to check if the documentation correctly updated according to the docstrings in the code added. - -## Repository Organization - -The SDK source (excluding tests and support tools) is organized into the -following packages/modules: -* `data/` package contains code that maps annotations (labels or pre-labels) to - Python objects, as well as serialization and deserialization tools for converting - between NDJson and Annotation Types. -* `orm/` package contains code that supports the general mapping of Labelbox - data to Python objects. This includes base classes, attribute (field and - relationship) classes, generic GraphQL queries etc. -* `schema/` package contains definitions of classes which represent data type - (e.g. Project, Label etc.). It relies on `orm/` classes for easy and succinct - object definitions. It also contains custom functionalities and custom GraphQL - templates where necessary. -* `client.py` contains the `Client` class that's the client-side stub for - communicating with Labelbox servers. -* `exceptions.py` contains declarations for all Labelbox errors. -* `pagination.py` contains support for paginated relationship and collection - fetching. -* `utils.py` contains utility functions. - -## Branches - -* All development happens in per-feature branches prefixed by contributor's - initials. For example `fs/feature_name`. -* Approved PRs are merged to the `develop` branch. -* The `develop` branch is merged to `master` on each release. - -## Formatting - -Before making a commit, to automatically adhere to our formatting standards, -install and activate [pre-commit](https://pre-commit.com/) -```shell -pip install pre-commit -pre-commit install -``` -After the above, running `git commit ...` will attempt to fix formatting, -and make necessary changes to files. You will then need to stage those files again. 
-
-You may also manually format your code by running the following:
-```shell
-yapf tests labelbox -i --verbose --recursive --parallel --style "google"
-```
-
-
-## Testing
-
-Currently, the SDK functionality is tested using unit and integration tests.
-The integration tests communicate with a Labelbox server (by default the staging server)
-and are in that sense not self-contained.
-
-Please consult the "Testing" section in the README for more details on how to test.
-
-Additionally, to execute tests you will need to provide an API key for the server you're using
-for testing (staging by default) in the `LABELBOX_TEST_API_KEY` environment
-variable. For more info, see [Labelbox API key docs](https://labelbox.helpdocs.io/docs/api/getting-started).
-
-
-## Release Steps
-
-Please consult the Labelbox team about releasing your contributions.
-
-## Running Jupyter Notebooks
-
-We have plenty of good samples in the _examples_ directory, and using them for testing can help increase productivity. One way to use Jupyter notebooks is to run the Jupyter server locally, which is fast (another way is to use a VS Code plugin, not documented here).
-
-Make sure your notebook will use your source code:
-1. `ipython profile create`
-2. `ipython locate` will show where the config file is. This is the config file used by the Jupyter server, since it runs via IPython.
-3. Open the file (this should be `ipython_config.py`, usually located in `~/.ipython/profile_default`) and add the following lines:
-```
-c.InteractiveShellApp.exec_lines = [
-    'import sys; sys.path.insert(0, "")'
-]
-```
-4. Go to the root of your project and run `jupyter notebook` to start the server.
-
-
-----
-setup.py
-import setuptools
-
-with open('labelbox/__init__.py') as fid:
-    for line in fid:
-        if line.startswith('__version__'):
-            SDK_VERSION = line.strip().split()[-1][1:-1]
-            break
-
-with open("README.md", "r") as fh:
-    long_description = fh.read()
-
-setuptools.setup(
-    name="labelbox",
-    version=SDK_VERSION,
-    author="Labelbox",
-    author_email="engineering@labelbox.com",
-    description="Labelbox Python API",
-    long_description=long_description,
-    long_description_content_type="text/markdown",
-    url="https://labelbox.com",
-    packages=setuptools.find_packages(),
-    install_requires=[
-        "requests>=2.22.0", "google-api-core>=1.22.1", "pydantic>=1.8", "tqdm",
-        "python-dateutil>=2.8.2,<2.9.0"
-    ],
-    extras_require={
-        'data': [
-            "shapely", "geojson", "numpy", "PILLOW", "opencv-python",
-            "typeguard", "imagesize", "pyproj", "pygeotile",
-            "typing-extensions", "packaging"
-        ],
-    },
-    classifiers=[
-        'License :: OSI Approved :: Apache Software License',
-        'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.7',
-        'Programming Language :: Python :: 3.8',
-        'Programming Language :: Python :: 3.9',
-    ],
-    python_requires='>=3.7',
-    keywords=["labelbox"],
-)
-
-----
-tests/conftest.py
-import glob
-from datetime import datetime
-from random import randint
-from string import ascii_letters
-
-import pytest
-
-pytest_plugins = [
-    fixture_file.replace("tests/", "").replace("/", ".").replace(".py", "")
-    for fixture_file in glob.glob(
-        "tests/integration/annotation_import/fixtures/[!__]*.py",)
-]
-
-
-@pytest.fixture(scope="session")
-def rand_gen():
-
-    def gen(field_type):
-        if field_type is str:
-            return "".join(ascii_letters[randint(0,
-                                                 len(ascii_letters) - 1)]
-                           for _ in range(16))
-
-        if field_type is datetime:
-            return datetime.now()
-
-        raise Exception("Can't random generate for field type '%r'"
% - field_type) - - return gen - ----- -tests/utils.py -def remove_keys_recursive(d, keys): - for k in keys: - if k in d: - del d[k] - for k, v in d.items(): - if isinstance(v, dict): - remove_keys_recursive(v, keys) - elif isinstance(v, list): - for i in v: - if isinstance(i, dict): - remove_keys_recursive(i, keys) - - -# NOTE this uses quite a primitive check for cuids but I do not think it is worth coming up with a better one -# Also this function is NOT written with performance in mind, good for small to mid size dicts like we have in our test -def rename_cuid_key_recursive(d): - new_key = '' - for k in list(d.keys()): - if len(k) == 25 and not k.isalpha(): #primitive check for cuid - d[new_key] = d.pop(k) - for k, v in d.items(): - if isinstance(v, dict): - rename_cuid_key_recursive(v) - elif isinstance(v, list): - for i in v: - if isinstance(i, dict): - rename_cuid_key_recursive(i) - - -INTEGRATION_SNAPSHOT_DIRECTORY = 'tests/integration/snapshots' - ----- -tests/unit/test_ndjson_parsing.py -import ast -from io import StringIO - -from labelbox import parser - - -def test_loads(ndjson_content): - expected_line, expected_objects = ndjson_content - parsed_line = parser.loads(expected_line) - - assert parsed_line == expected_objects - assert parser.dumps(parsed_line) == expected_line - - -def test_loads_bytes(ndjson_content): - expected_line, expected_objects = ndjson_content - - bytes_line = expected_line.encode('utf-8') - parsed_line = parser.loads(bytes_line) - - assert parsed_line == expected_objects - assert parser.dumps(parsed_line) == expected_line - - -def test_reader_stringio(ndjson_content): - line, ndjson_objects = ndjson_content - - text_io = StringIO(line) - parsed_arr = [] - reader = parser.reader(text_io) - for _, r in enumerate(reader): - parsed_arr.append(r) - assert parsed_arr == ndjson_objects - - -def test_non_ascii_new_line(ndjson_content_with_nonascii_and_line_breaks): - line, expected_objects = ndjson_content_with_nonascii_and_line_breaks - parsed = parser.loads(line) - - assert parsed == expected_objects - - # NOTE: json parser converts unicode chars to unicode literals by default and this is a good practice - # but it is not what we want here since we want to compare the strings with actual unicode chars - assert ast.literal_eval("'" + parser.dumps(parsed) + "'") == line - ----- -tests/unit/test_utils.py -import pytest -from labelbox.utils import format_iso_datetime, format_iso_from_string - - -@pytest.mark.parametrize('datetime_str, expected_datetime_str', - [('2011-11-04T00:05:23Z', '2011-11-04T00:05:23Z'), - ('2011-11-04T00:05:23+00:00', '2011-11-04T00:05:23Z'), - ('2011-11-04T00:05:23+05:00', '2011-11-03T19:05:23Z'), - ('2011-11-04T00:05:23', '2011-11-04T00:05:23Z')]) -def test_datetime_parsing(datetime_str, expected_datetime_str): - # NOTE I would normally not take 'expected' using another function from sdk code, but in this case this is exactly the usage in _validate_parse_datetime - assert format_iso_datetime( - format_iso_from_string(datetime_str)) == expected_datetime_str - ----- -tests/unit/conftest.py -import json -import pytest - - -@pytest.fixture -def ndjson_content(): - line = """{"uuid": "9fd9a92e-2560-4e77-81d4-b2e955800092", "schemaId": "ckaeasyfk004y0y7wyye5epgu", "dataRow": {"id": "ck7kftpan8ir008910yf07r9c"}, "bbox": {"top": 48, "left": 58, "height": 865, "width": 1512}} -{"uuid": "29b878f3-c2b4-4dbf-9f22-a795f0720125", "schemaId": "ckapgvrl7007q0y7ujkjkaaxt", "dataRow": {"id": "ck7kftpan8ir008910yf07r9c"}, "polygon": [{"x": 147.692, "y": 
118.154}, {"x": 142.769, "y": 404.923}, {"x": 57.846, "y": 318.769}, {"x": 28.308, "y": 169.846}]}""" - expected_objects = [{ - 'uuid': '9fd9a92e-2560-4e77-81d4-b2e955800092', - 'schemaId': 'ckaeasyfk004y0y7wyye5epgu', - 'dataRow': { - 'id': 'ck7kftpan8ir008910yf07r9c' - }, - 'bbox': { - 'top': 48, - 'left': 58, - 'height': 865, - 'width': 1512 - } - }, { - 'uuid': - '29b878f3-c2b4-4dbf-9f22-a795f0720125', - 'schemaId': - 'ckapgvrl7007q0y7ujkjkaaxt', - 'dataRow': { - 'id': 'ck7kftpan8ir008910yf07r9c' - }, - 'polygon': [{ - 'x': 147.692, - 'y': 118.154 - }, { - 'x': 142.769, - 'y': 404.923 - }, { - 'x': 57.846, - 'y': 318.769 - }, { - 'x': 28.308, - 'y': 169.846 - }] - }] - - return line, expected_objects - - -@pytest.fixture -def ndjson_content_with_nonascii_and_line_breaks(): - line = '{"id": "2489651127", "type": "PushEvent", "actor": {"id": 1459915, "login": "xtuaok", "gravatar_id": "", "url": "https://api.github.com/users/xtuaok", "avatar_url": "https://avatars.githubusercontent.com/u/1459915?"}, "repo": {"id": 6719841, "name": "xtuaok/twitter_track_following", "url": "https://api.github.com/repos/xtuaok/twitter_track_following"}, "payload": {"push_id": 536864008, "size": 1, "distinct_size": 1, "ref": "refs/heads/xtuaok", "head": "afb8afe306c7893d93d383a06e4d9df53b41bf47", "before": "4671b4868f1a060f2ed64d8268cd22d514a84e63", "commits": [{"sha": "afb8afe306c7893d93d383a06e4d9df53b41bf47", "author": {"email": "47cb89439b2d6961b59dff4298e837f67aa77389@gmail.com", "name": "Tomonori Tamagawa"}, "message": "Update ID 949438177,, - screen_name: chomado, - name: ちょまど@初詣おみくじ凶, - description: ( *゚▽゚* っ)З腐女子!絵描き!| H26新卒文系SE (入社して4ヶ月目の8月にSIer(適応障害になった)を辞職し開発者に転職) | H26秋応用情報合格!| 自作bot (in PHP) chomado_bot | プログラミングガチ初心者, - location:", "distinct": true, "url": "https://api.github.com/repos/xtuaok/twitter_track_following/commits/afb8afe306c7893d93d383a06e4d9df53b41bf47"}]}, "public": true, "created_at": "2015-01-01T15:00:10Z"}' - expected_objects = [{ - 'id': '2489651127', - 'type': 'PushEvent', - 'actor': { - 'id': 1459915, - 'login': 'xtuaok', - 'gravatar_id': '', - 'url': 'https://api.github.com/users/xtuaok', - 'avatar_url': 'https://avatars.githubusercontent.com/u/1459915?' 
- }, - 'repo': { - 'id': 6719841, - 'name': 'xtuaok/twitter_track_following', - 'url': 'https://api.github.com/repos/xtuaok/twitter_track_following' - }, - 'payload': { - 'push_id': - 536864008, - 'size': - 1, - 'distinct_size': - 1, - 'ref': - 'refs/heads/xtuaok', - 'head': - 'afb8afe306c7893d93d383a06e4d9df53b41bf47', - 'before': - '4671b4868f1a060f2ed64d8268cd22d514a84e63', - 'commits': [{ - 'sha': - 'afb8afe306c7893d93d383a06e4d9df53b41bf47', - 'author': { - 'email': - '47cb89439b2d6961b59dff4298e837f67aa77389@gmail.com', - 'name': - 'Tomonori Tamagawa' - }, - 'message': - 'Update ID 949438177,, - screen_name: chomado, - name: ちょまど@初詣おみくじ凶, - description: ( *゚▽゚* っ)З腐女子!絵描き!| H26新卒文系SE (入社して4ヶ月目の8月にSIer(適応障害になった)を辞職し開発者に転職) | H26秋応用情報合格!| 自作bot (in PHP) chomado_bot | プログラミングガチ初心者, - location:', - 'distinct': - True, - 'url': - 'https://api.github.com/repos/xtuaok/twitter_track_following/commits/afb8afe306c7893d93d383a06e4d9df53b41bf47' - }] - }, - 'public': True, - 'created_at': '2015-01-01T15:00:10Z' - }] - return line, expected_objects - - -@pytest.fixture -def generate_random_ndjson(rand_gen): - - def _generate_random_ndjson(lines: int = 10): - return [ - json.dumps({"data_row": { - "id": rand_gen(str) - }}) for _ in range(lines) - ] - - return _generate_random_ndjson - - -@pytest.fixture -def mock_response(): - - class MockResponse: - - def __init__(self, text: str, exception: Exception = None) -> None: - self._text = text - self._exception = exception - - @property - def text(self): - return self._text - - def raise_for_status(self): - if self._exception: - raise self._exception - - return MockResponse - ----- -tests/unit/test_unit_filter.py -import pytest - -from labelbox import Project -from labelbox.orm.comparison import Comparison, LogicalExpression - - -def test_comparison_creation(): - comparison = Comparison.Op.EQ(Project.name, "test") - assert comparison.op == Comparison.Op.EQ - assert comparison.field == Project.name - assert comparison.value == "test" - - -def test_comparison_equality(): - Op = Comparison.Op - assert Op.EQ(Project.name, "test") == Op.EQ(Project.name, "test") - assert Op.EQ(Project.name, "test") != Op.EQ(Project.uid, "test") - assert Op.EQ(Project.name, "test") != Op.EQ(Project.name, "t") - assert Op.EQ(Project.name, "test") != Op.NE(Project.name, "test") - - -def test_rich_comparison(): - Op = Comparison.Op - assert (Project.uid == "uid") == Op.EQ(Project.uid, "uid") - assert (Project.uid != "uid") == Op.NE(Project.uid, "uid") - assert (Project.uid < "uid") == Op.LT(Project.uid, "uid") - assert (Project.uid <= "uid") == Op.LE(Project.uid, "uid") - assert (Project.uid > "uid") == Op.GT(Project.uid, "uid") - assert (Project.uid >= "uid") == Op.GE(Project.uid, "uid") - - # inverse operands - assert ("uid" == Project.uid) == Op.EQ(Project.uid, "uid") - assert ("uid" != Project.uid) == Op.NE(Project.uid, "uid") - assert ("uid" < Project.uid) == Op.GT(Project.uid, "uid") - assert ("uid" <= Project.uid) == Op.GE(Project.uid, "uid") - assert ("uid" > Project.uid) == Op.LT(Project.uid, "uid") - assert ("uid" >= Project.uid) == Op.LE(Project.uid, "uid") - - -def test_logical_expr_creation(): - comparison_1 = Comparison.Op.EQ(Project.name, "name") - comparison_2 = Comparison.Op.LT(Project.uid, "uid") - - op = LogicalExpression.Op.AND(comparison_1, comparison_2) - assert op.op == LogicalExpression.Op.AND - assert op.first == comparison_1 - assert op.second == comparison_2 - - -def test_logical_expr_ops(): - comparison_1 = Comparison.Op.EQ(Project.name, "name") - 
comparison_2 = Comparison.Op.LT(Project.uid, "uid") - - log_op_1 = comparison_1 & comparison_2 - assert log_op_1 == LogicalExpression.Op.AND(comparison_1, comparison_2) - log_op_2 = log_op_1 | comparison_1 - assert log_op_2 == LogicalExpression.Op.OR(log_op_1, comparison_1) - log_op_3 = ~log_op_2 - assert log_op_3 == LogicalExpression.Op.NOT(log_op_2) - - # Can't create logical expressions with anything except comparisons and - # other logical expressions. - with pytest.raises(TypeError): - logical_op = comparison_1 & 42 - with pytest.raises(TypeError): - logical_op = comparison_1 | 42 - ----- -tests/unit/test_unit_delete_batch_data_row_metadata.py -from re import U - -from labelbox.schema.data_row_metadata import _DeleteBatchDataRowMetadata -from labelbox.schema.identifiable import GlobalKey, UniqueId - - -def test_dict_delete_data_row_batch(): - obj = _DeleteBatchDataRowMetadata( - data_row_identifier=UniqueId("abcd"), - schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"]) - assert obj.dict() == { - "data_row_identifier": { - "id": "abcd", - "id_type": "ID" - }, - "schema_ids": [ - "clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy" - ] - } - - obj = _DeleteBatchDataRowMetadata( - data_row_identifier=GlobalKey("fegh"), - schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"]) - assert obj.dict() == { - "data_row_identifier": { - "id": "fegh", - "id_type": "GKEY" - }, - "schema_ids": [ - "clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy" - ] - } - - -def test_dict_delete_data_row_batch_by_alias(): - obj = _DeleteBatchDataRowMetadata( - data_row_identifier=UniqueId("abcd"), - schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"]) - assert obj.dict(by_alias=True) == { - "dataRowIdentifier": { - "id": "abcd", - "idType": "ID" - }, - "schemaIds": ["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"] - } - - obj = _DeleteBatchDataRowMetadata( - data_row_identifier=GlobalKey("fegh"), - schema_ids=["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"]) - assert obj.dict(by_alias=True) == { - "dataRowIdentifier": { - "id": "fegh", - "idType": "GKEY" - }, - "schemaIds": ["clqh77tyk000008l2a9mjesa1", "clqh784br000008jy0yuq04fy"] - } - ----- -tests/unit/test_unit_webhook.py -from unittest.mock import MagicMock -import pytest - -from labelbox import Webhook - - -def test_webhook_create_with_no_secret(rand_gen): - client = MagicMock() - project = MagicMock() - secret = "" - url = "https:/" + rand_gen(str) - topics = [] - - with pytest.raises(ValueError) as exc_info: - Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "Secret must be a non-empty string." - - -def test_webhook_create_with_no_topics(rand_gen): - client = MagicMock() - project = MagicMock() - secret = rand_gen(str) - url = "https:/" + rand_gen(str) - topics = [] - - with pytest.raises(ValueError) as exc_info: - Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "Topics must be a non-empty list." - - -def test_webhook_create_with_no_url(rand_gen): - client = MagicMock() - project = MagicMock() - secret = rand_gen(str) - url = "" - topics = [Webhook.Topic.LABEL_CREATED, Webhook.Topic.LABEL_DELETED] - - with pytest.raises(ValueError) as exc_info: - Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "URL must be a non-empty string." 
-
-----
-tests/unit/test_unit_project_validate_labeling_parameter_overrides.py
-import pytest
-from unittest.mock import MagicMock
-
-from labelbox.schema.data_row import DataRow
-from labelbox.schema.identifiable import GlobalKey, UniqueId
-from labelbox.schema.project import validate_labeling_parameter_overrides
-
-
-def test_validate_labeling_parameter_overrides_valid_data():
-    mock_data_row = MagicMock(spec=DataRow)
-    mock_data_row.uid = "abc"
-    data = [(mock_data_row, 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)]
-    validate_labeling_parameter_overrides(data)
-
-
-def test_validate_labeling_parameter_overrides_invalid_data():
-    data = [("abc", 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)]
-    with pytest.raises(TypeError):
-        validate_labeling_parameter_overrides(data)
-
-
-def test_validate_labeling_parameter_overrides_invalid_priority():
-    mock_data_row = MagicMock(spec=DataRow)
-    mock_data_row.uid = "abc"
-    data = [(mock_data_row, "invalid"), (UniqueId("efg"), 2),
-            (GlobalKey("hij"), 3)]
-    with pytest.raises(TypeError):
-        validate_labeling_parameter_overrides(data)
-
-
-def test_validate_labeling_parameter_overrides_invalid_tuple_length():
-    mock_data_row = MagicMock(spec=DataRow)
-    mock_data_row.uid = "abc"
-    data = [(mock_data_row, "invalid"), (UniqueId("efg"), 2),
-            (GlobalKey("hij"))]
-    with pytest.raises(TypeError):
-        validate_labeling_parameter_overrides(data)
-
-----
-tests/unit/test_unit_label_import.py
-import uuid
-import pytest
-from unittest.mock import MagicMock, patch
-
-from labelbox.schema.annotation_import import LabelImport, logger
-
-
-def test_should_warn_user_about_unsupported_confidence():
-    """this test should check that a warning is logged for unsupported confidence scores"""
-    id = str(uuid.uuid4())
-
-    labels = [
-        {
-            "uuid": "b862c586-8614-483c-b5e6-82810f70cac0",
-            "schemaId": "ckrazcueb16og0z6609jj7y3y",
-            "dataRow": {
-                "id": "ckrazctum0z8a0ybc0b0o0g0v"
-            },
-            "confidence": 0.851,
-            "bbox": {
-                "top": 1352,
-                "left": 2275,
-                "height": 350,
-                "width": 139
-            }
-        },
-    ]
-    with patch.object(LabelImport, '_create_label_import_from_bytes'):
-        with patch.object(logger, 'warning') as warning_mock:
-            LabelImport.create_from_objects(client=MagicMock(),
-                                            project_id=id,
-                                            name=id,
-                                            labels=labels)
-            warning_mock.assert_called_once()
-            assert "Confidence scores are not supported in Label Import" in warning_mock.call_args_list[
-                0].args[0]
-
-
-def test_invalid_labels_format():
-    """this test should confirm that labels are required to be passed as a list"""
-    id = str(uuid.uuid4())
-
-    label = {
-        "uuid": "b862c586-8614-483c-b5e6-82810f70cac0",
-        "schemaId": "ckrazcueb16og0z6609jj7y3y",
-        "dataRow": {
-            "id": "ckrazctum0z8a0ybc0b0o0g0v"
-        },
-        "bbox": {
-            "top": 1352,
-            "left": 2275,
-            "height": 350,
-            "width": 139
-        }
-    }
-    with patch.object(LabelImport, '_create_label_import_from_bytes'):
-        with pytest.raises(TypeError):
-            LabelImport.create_from_objects(client=MagicMock(),
-                                            project_id=id,
-                                            name=id,
-                                            labels=label)
-
-----
-tests/unit/test_unit_case_change.py
-from labelbox import utils
-
-SNAKE = "this_is_a_string"
-TITLE = "ThisIsAString"
-CAMEL = "thisIsAString"
-MIXED = "this_Is_AString"
-
-
-def test_camel():
-    assert utils.camel_case(SNAKE) == CAMEL
-    assert utils.camel_case(TITLE) == CAMEL
-    assert utils.camel_case(CAMEL) == CAMEL
-    assert utils.camel_case(MIXED) == CAMEL
-
-
-def test_snake():
-    assert utils.snake_case(SNAKE) == SNAKE
-    assert utils.snake_case(TITLE) == SNAKE
-    assert utils.snake_case(CAMEL) == SNAKE
-    assert utils.snake_case(MIXED) == SNAKE
- - -def test_title(): - assert utils.title_case(SNAKE) == TITLE - assert utils.title_case(TITLE) == TITLE - assert utils.title_case(CAMEL) == TITLE - assert utils.title_case(MIXED) == TITLE - ----- -tests/unit/test_queue_mode.py -import pytest - -from labelbox.schema.queue_mode import QueueMode - - -def test_parse_deprecated_catalog(): - assert QueueMode("CATALOG") == QueueMode.Batch - - -def test_parse_batch(): - assert QueueMode("BATCH") == QueueMode.Batch - - -def test_parse_data_set(): - assert QueueMode("DATA_SET") == QueueMode.Dataset - - -def test_fails_for_unknown(): - with pytest.raises(ValueError): - QueueMode("foo") - ----- -tests/unit/test_unit_query.py -import pytest - -from labelbox import Project, Dataset -from labelbox.orm import query -from labelbox.orm.comparison import Comparison, LogicalExpression - - -def format(*args, **kwargs): - return query.Query(*args, **kwargs).format()[0] - - -def test_query_what(): - assert format("first", Project).startswith("first{") - assert format("other", Project).startswith("other{") - - -def test_query_subquery(): - assert format("x", query.Query("sub", Project)).startswith("x{sub{") - assert format("x", query.Query("bus", Project)).startswith("x{bus{") - - -def test_query_where(): - q, p = query.Query("x", Project, Project.name > "name").format() - assert q.startswith("x(where: {name_gt: $param_0}){") - assert p == {"param_0": ("name", Project.name)} - - q, p = query.Query("x", Project, - (Project.name != "name") & (Project.uid <= 42)).format() - assert q.startswith( - "x(where: {AND: [{name_not: $param_0}, {id_lte: $param_1}]}") - assert p == { - "param_0": ("name", Project.name), - "param_1": (42, Project.uid) - } - - -def test_query_param_declaration(): - q, _ = query.Query("x", Project, Project.name > "name").format_top("y") - assert q.startswith("query yPyApi($param_0: String!){x") - - q, _ = query.Query("x", Project, (Project.name > "name") & - (Project.uid == 42)).format_top("y") - assert q.startswith("query yPyApi($param_0: String!, $param_1: ID!){x") - - -def test_query_order_by(): - q, _ = query.Query("x", Project, order_by=Project.name.asc).format() - assert q.startswith("x(orderBy: name_ASC){") - - q, _ = query.Query("x", Project, order_by=Project.uid.desc).format() - assert q.startswith("x(orderBy: id_DESC){") - - -def test_logical_ops(): - Op = LogicalExpression.Op - - assert list(query.logical_ops(None)) == [] - comparison_1 = Comparison.Op.EQ(Project.name, "name") - assert list(query.logical_ops(comparison_1)) == [] - comparison_2 = Comparison.Op.LT(Dataset.uid, "uid") - assert list(query.logical_ops(comparison_2)) == [] - op_1 = Op.AND(comparison_1, comparison_2) - assert list(query.logical_ops(op_1)) == [Op.AND] - op_2 = Op.OR(comparison_1, comparison_2) - assert list(query.logical_ops(op_2)) == [Op.OR] - op_3 = Op.OR(op_1, op_2) - assert list(query.logical_ops(op_3)) == [Op.OR, Op.AND, Op.OR] - ----- -tests/unit/test_unit_ontology.py -import pytest - -from labelbox.exceptions import InconsistentOntologyException -from labelbox import Tool, Classification, Option, OntologyBuilder - -_SAMPLE_ONTOLOGY = { - "tools": [{ - "schemaNodeId": None, - "featureSchemaId": None, - "required": False, - "name": "poly", - "color": "#FF0000", - "tool": "polygon", - "classifications": [] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "required": False, - "name": "segment", - "color": "#FF0000", - "tool": "superpixel", - "classifications": [] - }, { - "schemaNodeId": - None, - "featureSchemaId": - None, - "required": - 
False, - "name": - "bbox", - "color": - "#FF0000", - "tool": - "rectangle", - "classifications": [{ - "schemaNodeId": - None, - "featureSchemaId": - None, - "required": - True, - "instructions": - "nested classification", - "name": - "nested classification", - "type": - "radio", - "options": [{ - "schemaNodeId": - None, - "featureSchemaId": - None, - "label": - "first", - "value": - "first", - "options": [{ - "schemaNodeId": None, - "featureSchemaId": None, - "required": False, - "instructions": "nested nested text", - "name": "nested nested text", - "type": "text", - "options": [] - }] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "label": "second", - "value": "second", - "options": [] - }] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "required": True, - "instructions": "nested text", - "name": "nested text", - "type": "text", - "options": [] - }] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "required": False, - "name": "dot", - "color": "#FF0000", - "tool": "point", - "classifications": [] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "required": False, - "name": "polyline", - "color": "#FF0000", - "tool": "line", - "classifications": [] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "required": False, - "name": "ner", - "color": "#FF0000", - "tool": "named-entity", - "classifications": [] - }], - "classifications": [{ - "schemaNodeId": - None, - "featureSchemaId": - None, - "required": - True, - "instructions": - "This is a question.", - "name": - "This is a question.", - "type": - "radio", - "scope": - "global", - "options": [{ - "schemaNodeId": None, - "featureSchemaId": None, - "label": "yes", - "value": "definitely yes", - "options": [] - }, { - "schemaNodeId": None, - "featureSchemaId": None, - "label": "no", - "value": "definitely not", - "options": [] - }] - }] -} - - -@pytest.mark.parametrize("tool_type", list(Tool.Type)) -def test_create_tool(tool_type) -> None: - t = Tool(tool=tool_type, name="tool") - assert (t.tool == tool_type) - - -@pytest.mark.parametrize("class_type", list(Classification.Type)) -def test_create_classification(class_type) -> None: - c = Classification(class_type=class_type, name="classification") - assert (c.class_type == class_type) - - -@pytest.mark.parametrize("value, expected_value, typing", - [(3, 3, int), ("string", "string", str)]) -def test_create_option_with_value(value, expected_value, typing) -> None: - o = Option(value=value) - assert (o.value == expected_value) - assert (o.value == o.label) - - -@pytest.mark.parametrize("value, label, expected_value, typing", - [(3, 2, 3, int), - ("string", "another string", "string", str)]) -def test_create_option_with_value_and_label(value, label, expected_value, - typing) -> None: - o = Option(value=value, label=label) - assert (o.value == expected_value) - assert o.value != o.label - assert isinstance(o.value, typing) - - -def test_create_empty_ontology() -> None: - o = OntologyBuilder() - assert (o.tools == []) - assert (o.classifications == []) - - -def test_add_ontology_tool() -> None: - o = OntologyBuilder() - o.add_tool(Tool(tool=Tool.Type.BBOX, name="bounding box")) - - second_tool = Tool(tool=Tool.Type.SEGMENTATION, name="segmentation") - o.add_tool(second_tool) - assert len(o.tools) == 2 - - for tool in o.tools: - assert (type(tool) == Tool) - - with pytest.raises(InconsistentOntologyException) as exc: - o.add_tool(Tool(tool=Tool.Type.BBOX, name="bounding box")) - assert "Duplicate tool name" in str(exc.value) - - -def 
test_add_ontology_classification() -> None: - o = OntologyBuilder() - o.add_classification( - Classification(class_type=Classification.Type.TEXT, name="text")) - - second_classification = Classification( - class_type=Classification.Type.CHECKLIST, name="checklist") - o.add_classification(second_classification) - assert len(o.classifications) == 2 - - for classification in o.classifications: - assert (type(classification) == Classification) - - with pytest.raises(InconsistentOntologyException) as exc: - o.add_classification( - Classification(class_type=Classification.Type.TEXT, name="text")) - assert "Duplicate classification name" in str(exc.value) - - -def test_tool_add_classification() -> None: - t = Tool(tool=Tool.Type.SEGMENTATION, name="segmentation") - c = Classification(class_type=Classification.Type.TEXT, name="text") - t.add_classification(c) - assert t.classifications == [c] - - with pytest.raises(Exception) as exc: - t.add_classification(c) - assert "Duplicate nested classification" in str(exc) - - -def test_classification_add_option() -> None: - c = Classification(class_type=Classification.Type.RADIO, name="radio") - o = Option(value="option") - c.add_option(o) - assert c.options == [o] - - with pytest.raises(InconsistentOntologyException) as exc: - c.add_option(Option(value="option")) - assert "Duplicate option" in str(exc.value) - - -def test_option_add_option() -> None: - o = Option(value="option") - c = Classification(class_type=Classification.Type.TEXT, name="text") - o.add_option(c) - assert o.options == [c] - - with pytest.raises(InconsistentOntologyException) as exc: - o.add_option(c) - assert "Duplicate nested classification" in str(exc.value) - - -def test_ontology_asdict() -> None: - assert OntologyBuilder.from_dict( - _SAMPLE_ONTOLOGY).asdict() == _SAMPLE_ONTOLOGY - - -def test_classification_using_instructions_instead_of_name_shows_warning(): - with pytest.warns(Warning): - Classification(class_type=Classification.Type.TEXT, instructions="text") - - -def test_classification_without_name_raises_error(): - with pytest.raises(ValueError): - Classification(class_type=Classification.Type.TEXT) - ----- -tests/unit/test_unit_entity_meta.py -import pytest - -from labelbox.orm.model import Relationship -from labelbox.orm.db_object import DbObject - - -def test_illegal_cache_cond1(): - - class TestEntityA(DbObject): - test_entity_b = Relationship.ToOne("TestEntityB", cache=True) - - with pytest.raises(TypeError) as exc_info: - - class TestEntityB(DbObject): - another_entity = Relationship.ToOne("AnotherEntity", cache=True) - - assert "`test_entity_a` caches `test_entity_b` which caches `['another_entity']`" in str( - exc_info.value) - - -def test_illegal_cache_cond2(): - - class TestEntityD(DbObject): - another_entity = Relationship.ToOne("AnotherEntity", cache=True) - - with pytest.raises(TypeError) as exc_info: - - class TestEntityC(DbObject): - test_entity_d = Relationship.ToOne("TestEntityD", cache=True) - - assert "`test_entity_c` caches `test_entity_d` which caches `['another_entity']`" in str( - exc_info.value) - ----- -tests/unit/test_unit_identifiables.py -from labelbox.schema.identifiables import GlobalKeys, UniqueIds - - -def test_unique_ids(): - ids = ["a", "b", "c"] - identifiables = UniqueIds(ids) - assert [i for i in identifiables] == ids - assert identifiables.id_type == "ID" - assert len(identifiables) == 3 - - -def test_global_keys(): - ids = ["a", "b", "c"] - identifiables = GlobalKeys(ids) - assert [i for i in identifiables] == ids - assert 
identifiables.id_type == "GKEY" - assert len(identifiables) == 3 - - -def test_index_access(): - ids = ["a", "b", "c"] - identifiables = GlobalKeys(ids) - assert identifiables[0] == "a" - assert identifiables[1:3] == GlobalKeys(["b", "c"]) - - -def test_repr(): - ids = ["a", "b", "c"] - identifiables = GlobalKeys(ids) - assert repr(identifiables) == "GlobalKeys(['a', 'b', 'c'])" - - ids = ["a", "b", "c"] - identifiables = UniqueIds(ids) - assert repr(identifiables) == "UniqueIds(['a', 'b', 'c'])" - ----- -tests/unit/test_annotation_import.py -import pytest - -from labelbox.schema.annotation_import import AnnotationImport - - -def test_data_row_validation_errors(): - predictions = [ - { - "answer": { - "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", - }, - "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - }, - { - "answer": { - "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", - }, - "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - }, - { - "answer": { - "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", - }, - "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - }, - { - "answer": { - "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", - }, - "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - }, - { - "answer": { - "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", - }, - "schemaId": "c123", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - }, - ] - - # Set up data for validation errors - # Invalid: Remove 'dataRow' part entirely - del predictions[0]['dataRow'] - - # Invalid: Set both id and globalKey - predictions[1]['dataRow'] = { - 'id': 'some id', - 'globalKey': 'some global key' - } - - # Invalid: Set both id and globalKey to None - predictions[2]['dataRow'] = {'id': None, 'globalKey': None} - - # Valid - predictions[3]['dataRow'] = { - 'id': 'some id', - } - - # Valid - predictions[4]['dataRow'] = { - 'globalKey': 'some global key', - } - - with pytest.raises(ValueError) as exc_info: - AnnotationImport._validate_data_rows(predictions) - exception_str = str(exc_info.value) - assert "Found 3 annotations with errors" in exception_str - assert "'dataRow' is missing in" in exception_str - assert "Must provide only one of 'id' or 'globalKey' for 'dataRow'" in exception_str - assert "'dataRow': {'id': 'some id', 'globalKey': 'some global key'}" in exception_str - assert "'dataRow': {'id': None, 'globalKey': None}" in exception_str - ----- -tests/unit/test_mal_import.py -import uuid -import pytest -from unittest.mock import MagicMock, patch - -from labelbox.schema.annotation_import import MALPredictionImport, logger - - -def test_should_warn_user_about_unsupported_confidence(): - """this test should check running state only to validate running, not completed""" - id = str(uuid.uuid4()) - - labels = [ - { - "bbox": { - "height": 428, - "left": 2089, - "top": 1251, - "width": 158 - }, - "classifications": [{ - "answer": [{ - "schemaId": "ckrb1sfl8099e0y919v260awv", - "confidence": 0.894 - }], - "schemaId": "ckrb1sfkn099c0y910wbo0p1a" - }], - "dataRow": { - "id": "ckrb1sf1i1g7i0ybcdc6oc8ct" - }, - "schemaId": "ckrb1sfjx099a0y914hl319ie", - "uuid": "d009925d-91a3-4f67-abd9-753453f5a584" - }, - ] - with patch.object(MALPredictionImport, '_create_mal_import_from_bytes'): - with patch.object(logger, 'warning') as warning_mock: - MALPredictionImport.create_from_objects(client=MagicMock(), - project_id=id, - name=id, - predictions=labels) - 
warning_mock.assert_called_once() - "Confidence scores are not supported in MAL Prediction Import" in warning_mock.call_args_list[ - 0].args[0] - - -def test_invalid_labels_format(): - """this test should confirm that annotations are required to be in a form of list""" - id = str(uuid.uuid4()) - - label = { - "bbox": { - "height": 428, - "left": 2089, - "top": 1251, - "width": 158 - }, - "classifications": [{ - "answer": [{ - "schemaId": "ckrb1sfl8099e0y919v260awv", - "confidence": 0.894 - }], - "schemaId": "ckrb1sfkn099c0y910wbo0p1a" - }], - "dataRow": { - "id": "ckrb1sf1i1g7i0ybcdc6oc8ct" - }, - "schemaId": "ckrb1sfjx099a0y914hl319ie", - "uuid": "3a83db52-75e0-49af-a171-234ce604502a" - } - - with patch.object(MALPredictionImport, '_create_mal_import_from_bytes'): - with pytest.raises(TypeError): - MALPredictionImport.create_from_objects(client=MagicMock(), - project_id=id, - name=id, - predictions=label) - ----- -tests/unit/test_unit_export_filters.py -from unittest.mock import MagicMock - -import pytest - -from labelbox.schema.export_filters import build_filters - - -def test_ids_filter(): - client = MagicMock() - filters = {"data_row_ids": ["id1", "id2"], "batch_ids": ["b1", "b2"]} - assert build_filters(client, filters) == [{ - "ids": ["id1", "id2"], - "operator": "is", - "type": "data_row_id", - }, { - "ids": ["b1", "b2"], - "operator": "is", - "type": "batch", - }] - - -def test_ids_empty_filter(): - client = MagicMock() - filters = {"data_row_ids": [], "batch_ids": ["b1", "b2"]} - with pytest.raises(ValueError, - match="data_row_id filter expects a non-empty list."): - build_filters(client, filters) - - -def test_global_keys_filter(): - client = MagicMock() - filters = {"global_keys": ["id1", "id2"]} - assert build_filters(client, filters) == [{ - "ids": ["id1", "id2"], - "operator": "is", - "type": "global_key", - }] - - -def test_validations(): - client = MagicMock() - filters = { - "global_keys": ["id1", "id2"], - "data_row_ids": ["id1", "id2"], - } - with pytest.raises( - ValueError, - match= - "data_rows and global_keys cannot both be present in export filters" - ): - build_filters(client, filters) - ----- -tests/unit/test_unit_rand_gen.py -def test_gen_str(rand_gen): - assert len({rand_gen(str) for _ in range(100)}) == 100 - ----- -tests/unit/export_task/test_unit_file_retriever_by_line.py -from unittest.mock import MagicMock, patch -from labelbox.schema.export_task import ( - FileRetrieverByLine, - _TaskContext, - _MetadataHeader, - StreamType, -) - - -class TestFileRetrieverByLine: - - def test_by_line_from_start(self, generate_random_ndjson, mock_response): - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - - mock_client = MagicMock() - mock_client.execute = MagicMock( - return_value={ - "task": { - "exportFileFromLine": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, - "file": "http://some-url.com/file.ndjson", - } - } - }) - - mock_ctx = _TaskContext( - client=mock_client, - task_id="task-id", - stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), - ) - - with patch("requests.get", return_value=mock_response(file_content)): - retriever = FileRetrieverByLine(mock_ctx, 0) - info, content = retriever.get_next_chunk() - assert info.offsets.start == 0 - assert info.offsets.end == len(file_content) - 1 - assert info.lines.start == 0 - assert info.lines.end == line_count 
- 1 - assert info.file == "http://some-url.com/file.ndjson" - assert content == file_content - - def test_by_line_from_middle(self, generate_random_ndjson, mock_response): - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - - mock_client = MagicMock() - mock_client.execute = MagicMock( - return_value={ - "task": { - "exportFileFromLine": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, - "file": "http://some-url.com/file.ndjson", - } - } - }) - - mock_ctx = _TaskContext( - client=mock_client, - task_id="task-id", - stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), - ) - - line_start = 5 - current_offset = file_content.find(ndjson[line_start]) - - with patch("requests.get", return_value=mock_response(file_content)): - retriever = FileRetrieverByLine(mock_ctx, line_start) - info, content = retriever.get_next_chunk() - assert info.offsets.start == current_offset - assert info.offsets.end == len(file_content) - 1 - assert info.lines.start == line_start - assert info.lines.end == line_count - 1 - assert info.file == "http://some-url.com/file.ndjson" - assert content == file_content[current_offset:] - - def test_by_line_from_last(self, generate_random_ndjson, mock_response): - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - - mock_client = MagicMock() - mock_client.execute = MagicMock( - return_value={ - "task": { - "exportFileFromLine": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, - "file": "http://some-url.com/file.ndjson", - } - } - }) - - mock_ctx = _TaskContext( - client=mock_client, - task_id="task-id", - stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), - ) - - line_start = 9 - current_offset = file_content.find(ndjson[line_start]) - - with patch("requests.get", return_value=mock_response(file_content)): - retriever = FileRetrieverByLine(mock_ctx, line_start) - info, content = retriever.get_next_chunk() - assert info.offsets.start == current_offset - assert info.offsets.end == len(file_content) - 1 - assert info.lines.start == line_start - assert info.lines.end == line_count - 1 - assert info.file == "http://some-url.com/file.ndjson" - assert content == file_content[current_offset:] - ----- -tests/unit/export_task/test_unit_file_retriever_by_offset.py -from unittest.mock import MagicMock, patch -from labelbox.schema.export_task import ( - FileRetrieverByOffset, - _TaskContext, - _MetadataHeader, - StreamType, -) - - -class TestFileRetrieverByOffset: - - def test_by_offset_from_start(self, generate_random_ndjson, mock_response): - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - - mock_client = MagicMock() - mock_client.execute = MagicMock( - return_value={ - "task": { - "exportFileFromOffset": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, - "file": "http://some-url.com/file.ndjson", - } - } - }) - - mock_ctx = _TaskContext( - client=mock_client, - task_id="task-id", - stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), - ) - - with patch("requests.get", 
return_value=mock_response(file_content)): - retriever = FileRetrieverByOffset(mock_ctx, 0) - info, content = retriever.get_next_chunk() - assert info.offsets.start == 0 - assert info.offsets.end == len(file_content) - 1 - assert info.lines.start == 0 - assert info.lines.end == line_count - 1 - assert info.file == "http://some-url.com/file.ndjson" - assert content == file_content - - def test_by_offset_from_middle(self, generate_random_ndjson, mock_response): - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - - mock_client = MagicMock() - mock_client.execute = MagicMock( - return_value={ - "task": { - "exportFileFromOffset": { - "offsets": { - "start": "0", - "end": len(file_content) - 1 - }, - "lines": { - "start": "0", - "end": str(line_count - 1) - }, - "file": "http://some-url.com/file.ndjson", - } - } - }) - - mock_ctx = _TaskContext( - client=mock_client, - task_id="task-id", - stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), - ) - - line_start = 5 - skipped_bytes = 15 - current_offset = file_content.find(ndjson[line_start]) + skipped_bytes - - with patch("requests.get", return_value=mock_response(file_content)): - retriever = FileRetrieverByOffset(mock_ctx, current_offset) - info, content = retriever.get_next_chunk() - assert info.offsets.start == current_offset - assert info.offsets.end == len(file_content) - 1 - assert info.lines.start == 5 - assert info.lines.end == line_count - 1 - assert info.file == "http://some-url.com/file.ndjson" - assert content == file_content[current_offset:] - ----- -tests/unit/export_task/test_unit_file_converter.py -from unittest.mock import MagicMock - -from labelbox.schema.export_task import ( - Converter, - FileConverter, - Range, - StreamType, - _MetadataFileInfo, - _MetadataHeader, - _TaskContext, -) - - -class TestFileConverter: - - def test_with_correct_ndjson(self, tmp_path, generate_random_ndjson): - directory = tmp_path / "file-converter" - directory.mkdir() - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - input_args = Converter.ConverterInputArgs( - ctx=_TaskContext( - client=MagicMock(), - task_id="task-id", - stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), - ), - file_info=_MetadataFileInfo( - offsets=Range(start=0, end=len(file_content) - 1), - lines=Range(start=0, end=line_count - 1), - file="file.ndjson", - ), - raw_data=file_content, - ) - path = directory / "output.ndjson" - with FileConverter(file_path=path) as converter: - for output in converter.convert(input_args): - assert output.current_line == 0 - assert output.current_offset == 0 - assert output.file_path == path - assert output.total_lines == line_count - assert output.total_size == len(file_content) - assert output.bytes_written == len(file_content) - - def test_with_no_newline_at_end(self, tmp_path, generate_random_ndjson): - directory = tmp_path / "file-converter" - directory.mkdir() - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) - input_args = Converter.ConverterInputArgs( - ctx=_TaskContext( - client=MagicMock(), - task_id="task-id", - stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader(total_size=len(file_content), - total_lines=line_count), - ), - file_info=_MetadataFileInfo( - offsets=Range(start=0, end=len(file_content) - 1), - 
lines=Range(start=0, end=line_count - 1), - file="file.ndjson", - ), - raw_data=file_content, - ) - path = directory / "output.ndjson" - with FileConverter(file_path=path) as converter: - for output in converter.convert(input_args): - assert output.current_line == 0 - assert output.current_offset == 0 - assert output.file_path == path - assert output.total_lines == line_count - assert output.total_size == len(file_content) - assert output.bytes_written == len(file_content) - ----- -tests/unit/export_task/test_unit_json_converter.py -from unittest.mock import MagicMock - -from labelbox.schema.export_task import Converter, JsonConverter, Range, _MetadataFileInfo - - -class TestJsonConverter: - - def test_with_correct_ndjson(self, generate_random_ndjson): - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - input_args = Converter.ConverterInputArgs( - ctx=MagicMock(), - file_info=_MetadataFileInfo( - offsets=Range(start=0, end=len(file_content) - 1), - lines=Range(start=0, end=line_count - 1), - file="file.ndjson", - ), - raw_data=file_content, - ) - with JsonConverter() as converter: - current_offset = 0 - for idx, output in enumerate(converter.convert(input_args)): - assert output.current_line == idx - assert output.current_offset == current_offset - assert output.json_str == ndjson[idx] - current_offset += len(output.json_str) + 1 - - def test_with_no_newline_at_end(self, generate_random_ndjson): - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) - input_args = Converter.ConverterInputArgs( - ctx=MagicMock(), - file_info=_MetadataFileInfo( - offsets=Range(start=0, end=len(file_content) - 1), - lines=Range(start=0, end=line_count - 1), - file="file.ndjson", - ), - raw_data=file_content, - ) - with JsonConverter() as converter: - current_offset = 0 - for idx, output in enumerate(converter.convert(input_args)): - assert output.current_line == idx - assert output.current_offset == current_offset - assert output.json_str == ndjson[idx] - current_offset += len(output.json_str) + 1 - - def test_from_offset(self, generate_random_ndjson): - # testing middle of a JSON string, but not the last line - line_count = 10 - line_start = 5 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - offset_end = len(file_content) - skipped_bytes = 15 - current_offset = file_content.find(ndjson[line_start]) + skipped_bytes - file_content = file_content[current_offset:] - - input_args = Converter.ConverterInputArgs( - ctx=MagicMock(), - file_info=_MetadataFileInfo( - offsets=Range(start=current_offset, end=offset_end), - lines=Range(start=line_start, end=line_count - 1), - file="file.ndjson", - ), - raw_data=file_content, - ) - with JsonConverter() as converter: - for idx, output in enumerate(converter.convert(input_args)): - assert output.current_line == line_start + idx - assert output.current_offset == current_offset - assert output.json_str == ndjson[line_start + - idx][skipped_bytes:] - current_offset += len(output.json_str) + 1 - skipped_bytes = 0 - - def test_from_offset_last_line(self, generate_random_ndjson): - # testing middle of a JSON string, but not the last line - line_count = 10 - line_start = 9 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - offset_end = len(file_content) - skipped_bytes = 15 - current_offset = file_content.find(ndjson[line_start]) + skipped_bytes - file_content = file_content[current_offset:] - - 
input_args = Converter.ConverterInputArgs( - ctx=MagicMock(), - file_info=_MetadataFileInfo( - offsets=Range(start=current_offset, end=offset_end), - lines=Range(start=line_start, end=line_count - 1), - file="file.ndjson", - ), - raw_data=file_content, - ) - with JsonConverter() as converter: - for idx, output in enumerate(converter.convert(input_args)): - assert output.current_line == line_start + idx - assert output.current_offset == current_offset - assert output.json_str == ndjson[line_start + - idx][skipped_bytes:] - current_offset += len(output.json_str) + 1 - skipped_bytes = 0 - ----- -tests/integration/test_labeling_frontend.py -from labelbox import LabelingFrontend - - -def test_get_labeling_frontends(client): - filtered_frontends = list( - client.get_labeling_frontends(where=LabelingFrontend.name == 'Editor')) - assert len(filtered_frontends) - - -def test_labeling_frontend_connecting_to_project(project): - assert project.labeling_frontend() == None - - frontend = list(project.client.get_labeling_frontends())[0] - - project.labeling_frontend.connect(frontend) - assert project.labeling_frontend() == frontend - - project.labeling_frontend.disconnect(frontend) - assert project.labeling_frontend() == None - ----- -tests/integration/test_pagination.py -from copy import copy -import time - -import pytest - -from labelbox.schema.dataset import Dataset - - -@pytest.fixture -def data_for_dataset_order_test(client, rand_gen): - name = rand_gen(str) - dataset1 = client.create_dataset(name=name) - dataset2 = client.create_dataset(name=name) - - yield name - - dataset1.delete() - dataset2.delete() - - -def test_get_one_and_many_dataset_order(client, data_for_dataset_order_test): - name = data_for_dataset_order_test - - paginator = client.get_datasets(where=Dataset.name == name) - # confirm get_one returns first dataset - all_datasets = list(paginator) - assert len(all_datasets) == 2 - get_one_dataset = copy(paginator).get_one() - assert get_one_dataset.uid == all_datasets[0].uid - - # confirm get_many(1) returns first dataset - get_many_datasets = copy(paginator).get_many(1) - assert get_many_datasets[0].uid == all_datasets[0].uid - ----- -tests/integration/conftest.py -from collections import defaultdict -from itertools import islice -import json -import os -import sys -import time -import uuid -from types import SimpleNamespace -from typing import Type, List - -import pytest -import requests - -from labelbox import Dataset, DataRow -from labelbox import LabelingFrontend -from labelbox import OntologyBuilder, Tool, Option, Classification, MediaType -from labelbox.orm import query -from labelbox.pagination import PaginatedCollection -from labelbox.schema.annotation_import import LabelImport -from labelbox.schema.catalog import Catalog -from labelbox.schema.enums import AnnotationImportState -from labelbox.schema.invite import Invite -from labelbox.schema.quality_mode import QualityMode -from labelbox.schema.queue_mode import QueueMode -from labelbox.schema.user import User -from support.integration_client import Environ, IntegrationClient, EphemeralClient, AdminClient - -IMG_URL = "https://picsum.photos/200/300.jpg" -MASKABLE_IMG_URL = "https://storage.googleapis.com/labelbox-datasets/image_sample_data/2560px-Kitano_Street_Kobe01s5s4110.jpeg" -SMALL_DATASET_URL = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg" -DATA_ROW_PROCESSING_WAIT_TIMEOUT_SECONDS = 30 -DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS = 3 - - -@pytest.fixture(scope="session") 
-def environ() -> Environ: - """ - Checks environment variables for LABELBOX_ENVIRON to be - 'prod' or 'staging' - - Make sure to set LABELBOX_TEST_ENVIRON in .github/workflows/python-package.yaml - - """ - try: - return Environ(os.environ['LABELBOX_TEST_ENVIRON']) - except KeyError: - raise Exception(f'Missing LABELBOX_TEST_ENVIRON in: {os.environ}') - - -def cancel_invite(client, invite_id): - """ - Do not use. Only for testing. - """ - query_str = """mutation CancelInvitePyApi($where: WhereUniqueIdInput!) { - cancelInvite(where: $where) {id}}""" - client.execute(query_str, {'where': {'id': invite_id}}, experimental=True) - - -def get_project_invites(client, project_id): - """ - Do not use. Only for testing. - """ - id_param = "projectId" - query_str = """query GetProjectInvitationsPyApi($from: ID, $first: PageSize, $%s: ID!) { - project(where: {id: $%s}) {id - invites(from: $from, first: $first) { nodes { %s - projectInvites { projectId projectRoleName } } nextCursor}}} - """ % (id_param, id_param, query.results_query_part(Invite)) - return PaginatedCollection(client, - query_str, {id_param: project_id}, - ['project', 'invites', 'nodes'], - Invite, - cursor_path=['project', 'invites', 'nextCursor']) - - -def get_invites(client): - """ - Do not use. Only for testing. - """ - query_str = """query GetOrgInvitationsPyApi($from: ID, $first: PageSize) { - organization { id invites(from: $from, first: $first) { - nodes { id createdAt organizationRoleName inviteeEmail } nextCursor }}}""" - invites = PaginatedCollection( - client, - query_str, {}, ['organization', 'invites', 'nodes'], - Invite, - cursor_path=['organization', 'invites', 'nextCursor'], - experimental=True) - return invites - - -@pytest.fixture -def queries(): - return SimpleNamespace(cancel_invite=cancel_invite, - get_project_invites=get_project_invites, - get_invites=get_invites) - - -@pytest.fixture(scope="session") -def admin_client(environ: str): - return AdminClient(environ) - - -@pytest.fixture(scope="session") -def client(environ: str): - if environ == Environ.EPHEMERAL: - return EphemeralClient() - return IntegrationClient(environ) - - -@pytest.fixture(scope="session") -def image_url(client): - return client.upload_data(requests.get(MASKABLE_IMG_URL).content, - content_type="image/jpeg", - filename="image.jpeg", - sign=True) - - -@pytest.fixture(scope="session") -def pdf_url(client): - pdf_url = client.upload_file('tests/assets/loremipsum.pdf') - return {"row_data": {"pdf_url": pdf_url,}, "global_key": str(uuid.uuid4())} - - -@pytest.fixture(scope="session") -def pdf_entity_data_row(client): - pdf_url = client.upload_file( - 'tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483.pdf') - text_layer_url = client.upload_file( - 'tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483-lb-textlayer.json' - ) - - return { - "row_data": { - "pdf_url": pdf_url, - "text_layer_url": text_layer_url - }, - "global_key": str(uuid.uuid4()) - } - - -@pytest.fixture() -def conversation_entity_data_row(client, rand_gen): - return { - "row_data": - "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", - "global_key": - f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{rand_gen(str)}", - } - - -@pytest.fixture -def project(client, rand_gen): - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) - yield project - project.delete() - 
- -@pytest.fixture -def consensus_project(client, rand_gen): - project = client.create_project(name=rand_gen(str), - quality_mode=QualityMode.Consensus, - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) - yield project - project.delete() - - -@pytest.fixture -def consensus_project_with_batch(consensus_project, initial_dataset, rand_gen, - image_url): - project = consensus_project - dataset = initial_dataset - - data_rows = [] - for _ in range(3): - data_rows.append({ - DataRow.row_data: image_url, - DataRow.global_key: str(uuid.uuid4()) - }) - task = dataset.create_data_rows(data_rows) - task.wait_till_done() - assert task.status == "COMPLETE" - - data_rows = list(dataset.data_rows()) - assert len(data_rows) == 3 - batch = project.create_batch( - rand_gen(str), - data_rows, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - - yield [project, batch, data_rows] - batch.delete() - - -@pytest.fixture -def dataset(client, rand_gen): - dataset = client.create_dataset(name=rand_gen(str)) - yield dataset - dataset.delete() - - -@pytest.fixture(scope='function') -def unique_dataset(client, rand_gen): - dataset = client.create_dataset(name=rand_gen(str)) - yield dataset - dataset.delete() - - -@pytest.fixture -def small_dataset(dataset: Dataset): - task = dataset.create_data_rows([ - { - "row_data": SMALL_DATASET_URL, - "external_id": "my-image" - }, - ] * 2) - task.wait_till_done() - - yield dataset - - -@pytest.fixture -def data_row(dataset, image_url, rand_gen): - global_key = f"global-key-{rand_gen(str)}" - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image", - "global_key": global_key - }, - ]) - task.wait_till_done() - dr = dataset.data_rows().get_one() - yield dr - dr.delete() - - -@pytest.fixture -def data_row_and_global_key(dataset, image_url, rand_gen): - global_key = f"global-key-{rand_gen(str)}" - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image", - "global_key": global_key - }, - ]) - task.wait_till_done() - dr = dataset.data_rows().get_one() - yield dr, global_key - dr.delete() - - -# can be used with -# @pytest.mark.parametrize('data_rows', [], indirect=True) -# if omitted, count defaults to 1 -@pytest.fixture -def data_rows(dataset, image_url, request, wait_for_data_row_processing, - client): - count = 1 - if hasattr(request, 'param'): - count = request.param - - datarows = [ - dict(row_data=image_url, global_key=f"global-key-{uuid.uuid4()}") - for _ in range(count) - ] - - task = dataset.create_data_rows(datarows) - task.wait_till_done() - datarows = dataset.data_rows().get_many(count) - for dr in dataset.data_rows(): - wait_for_data_row_processing(client, dr) - - yield datarows - - for datarow in datarows: - datarow.delete() - - -@pytest.fixture -def iframe_url(environ) -> str: - if environ in [Environ.PROD, Environ.LOCAL]: - return 'https://editor.labelbox.com' - elif environ == Environ.STAGING: - return 'https://editor.lb-stage.xyz' - - -@pytest.fixture -def sample_image() -> str: - path_to_video = 'tests/integration/media/sample_image.jpg' - return path_to_video - - -@pytest.fixture -def sample_video() -> str: - path_to_video = 'tests/integration/media/cat.mp4' - return path_to_video - - -@pytest.fixture -def sample_bulk_conversation() -> list: - path_to_conversation = 'tests/integration/media/bulk_conversation.json' - with open(path_to_conversation) as json_file: - conversations = json.load(json_file) - return conversations - - -@pytest.fixture -def 
organization(client): - # Must have at least one seat open in your org to run these tests - org = client.get_organization() - # Clean up before and after incase this wasn't run for some reason. - for invite in get_invites(client): - if "@labelbox.com" in invite.email: - cancel_invite(client, invite.uid) - yield org - for invite in get_invites(client): - if "@labelbox.com" in invite.email: - cancel_invite(client, invite.uid) - - -@pytest.fixture -def project_based_user(client, rand_gen): - email = rand_gen(str) - # Use old mutation because it doesn't require users to accept email invites - query_str = """mutation MakeNewUserPyApi { - addMembersToOrganization( - data: { - emails: ["%s@labelbox.com"], - orgRoleId: "%s", - projectRoles: [] - } - ) { - newUserId - } - } - """ % (email, str(client.get_roles()['NONE'].uid)) - user_id = client.execute( - query_str)['addMembersToOrganization'][0]['newUserId'] - assert user_id is not None, "Unable to add user with old mutation" - user = client._get_single(User, user_id) - yield user - client.get_organization().remove_user(user) - - -@pytest.fixture -def project_pack(client): - projects = [ - client.create_project(name=f"user-proj-{idx}", - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) for idx in range(2) - ] - yield projects - for proj in projects: - proj.delete() - - -@pytest.fixture -def initial_dataset(client, rand_gen): - dataset = client.create_dataset(name=rand_gen(str)) - yield dataset - - dataset.delete() - - -@pytest.fixture -def project_with_empty_ontology(project): - editor = list( - project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - empty_ontology = {"tools": [], "classifications": []} - project.setup(editor, empty_ontology) - yield project - - -@pytest.fixture -def configured_project(project_with_empty_ontology, initial_dataset, rand_gen, - image_url): - dataset = initial_dataset - data_row_id = dataset.create_data_row(row_data=image_url).uid - project = project_with_empty_ontology - - batch = project.create_batch( - rand_gen(str), - [data_row_id], # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - project.data_row_ids = [data_row_id] - - yield project - - batch.delete() - - -@pytest.fixture -def configured_project_with_label(client, rand_gen, image_url, project, dataset, - data_row, wait_for_label_processing): - """Project with a connected dataset, having one datarow - Project contains an ontology with 1 bbox tool - Additionally includes a create_label method for any needed extra labels - One label is already created and yielded when using fixture - """ - project._wait_until_data_rows_are_processed( - data_row_ids=[data_row.uid], - wait_processing_max_seconds=DATA_ROW_PROCESSING_WAIT_TIMEOUT_SECONDS, - sleep_interval=DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS) - - project.create_batch( - rand_gen(str), - [data_row.uid], # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - ontology = _setup_ontology(project) - label = _create_label(project, data_row, ontology, - wait_for_label_processing) - yield [project, dataset, data_row, label] - - for label in project.labels(): - label.delete() - - -@pytest.fixture -def configured_batch_project_with_label(project, dataset, data_row, - wait_for_label_processing): - """Project with a batch having one datarow - Project contains an ontology with 1 bbox tool - Additionally includes a create_label method for any needed extra labels - One label is already created and yielded when using fixture - 
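-    Yields [project, dataset, data_row, label]; created labels are deleted on teardown.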
""" - data_rows = [dr.uid for dr in list(dataset.data_rows())] - project._wait_until_data_rows_are_processed(data_row_ids=data_rows, - sleep_interval=3) - project.create_batch("test-batch", data_rows) - project.data_row_ids = data_rows - - ontology = _setup_ontology(project) - label = _create_label(project, data_row, ontology, - wait_for_label_processing) - - yield [project, dataset, data_row, label] - - for label in project.labels(): - label.delete() - - -@pytest.fixture -def configured_batch_project_with_multiple_datarows(project, dataset, data_rows, - wait_for_label_processing): - """Project with a batch having multiple datarows - Project contains an ontology with 1 bbox tool - Additionally includes a create_label method for any needed extra labels - """ - global_keys = [dr.global_key for dr in data_rows] - - batch_name = f'batch {uuid.uuid4()}' - project.create_batch(batch_name, global_keys=global_keys) - - ontology = _setup_ontology(project) - for datarow in data_rows: - _create_label(project, datarow, ontology, wait_for_label_processing) - - yield [project, dataset, data_rows] - - for label in project.labels(): - label.delete() - - -def _create_label(project, data_row, ontology, wait_for_label_processing): - predictions = [{ - "uuid": str(uuid.uuid4()), - "schemaId": ontology.tools[0].feature_schema_id, - "dataRow": { - "id": data_row.uid - }, - "bbox": { - "top": 20, - "left": 20, - "height": 50, - "width": 50 - } - }] - - def create_label(): - """ Ad-hoc function to create a LabelImport - Creates a LabelImport task which will create a label - """ - upload_task = LabelImport.create_from_objects( - project.client, project.uid, f'label-import-{uuid.uuid4()}', - predictions) - upload_task.wait_until_done(sleep_time_seconds=5) - assert upload_task.state == AnnotationImportState.FINISHED, "Label Import did not finish" - assert len( - upload_task.errors - ) == 0, f"Label Import {upload_task.name} failed with errors {upload_task.errors}" - - project.create_label = create_label - project.create_label() - label = wait_for_label_processing(project)[0] - return label - - -def _setup_ontology(project): - editor = list( - project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - ontology_builder = OntologyBuilder(tools=[ - Tool(tool=Tool.Type.BBOX, name="test-bbox-class"), - ]) - project.setup(editor, ontology_builder.asdict()) - # TODO: ontology may not be synchronous after setup. 
remove the sleep once the API is more consistent
-    time.sleep(2)
-    return OntologyBuilder.from_project(project)
-
-
-@pytest.fixture
-def configured_project_with_complex_ontology(client, initial_dataset, rand_gen,
-                                             image_url):
-    project = client.create_project(name=rand_gen(str),
-                                    queue_mode=QueueMode.Batch,
-                                    media_type=MediaType.Image)
-    dataset = initial_dataset
-    data_row = dataset.create_data_row(row_data=image_url)
-    data_row_ids = [data_row.uid]
-
-    project.create_batch(
-        rand_gen(str),
-        data_row_ids,  # sample of data row objects
-        5  # priority between 1(Highest) - 5(lowest)
-    )
-    project.data_row_ids = data_row_ids
-
-    editor = list(
-        project.client.get_labeling_frontends(
-            where=LabelingFrontend.name == "editor"))[0]
-
-    ontology = OntologyBuilder()
-    tools = [
-        Tool(tool=Tool.Type.BBOX, name="test-bbox-class"),
-        Tool(tool=Tool.Type.LINE, name="test-line-class"),
-        Tool(tool=Tool.Type.POINT, name="test-point-class"),
-        Tool(tool=Tool.Type.POLYGON, name="test-polygon-class"),
-        Tool(tool=Tool.Type.NER, name="test-ner-class")
-    ]
-
-    options = [
-        Option(value="first option answer"),
-        Option(value="second option answer"),
-        Option(value="third option answer")
-    ]
-
-    classifications = [
-        Classification(class_type=Classification.Type.TEXT,
-                       name="test-text-class"),
-        Classification(class_type=Classification.Type.DROPDOWN,
-                       name="test-dropdown-class",
-                       options=options),
-        Classification(class_type=Classification.Type.RADIO,
-                       name="test-radio-class",
-                       options=options),
-        Classification(class_type=Classification.Type.CHECKLIST,
-                       name="test-checklist-class",
-                       options=options)
-    ]
-
-    for t in tools:
-        for c in classifications:
-            t.add_classification(c)
-        ontology.add_tool(t)
-    for c in classifications:
-        ontology.add_classification(c)
-
-    project.setup(editor, ontology.asdict())
-
-    yield [project, data_row]
-    project.delete()
-
-
-# NOTE: this heuristic works well; Project also implements
-# _wait_until_data_rows_are_processed, which we can fall back on
-# if we still see flakiness in the future
-@pytest.fixture
-def wait_for_data_row_processing():
-    """
-    Do not use. Only for testing.
-
-    Returns DataRow after waiting for it to finish processing media_attributes.
-    Some tests, specifically ones that rely on label export, rely on the
-    DataRow being fully processed with media_attributes
-    """
-
-    def func(client, data_row, compare_with_prev_media_attrs=False):
-        """
-        compare_with_prev_media_attrs is needed because when a data_row is
-        updated from, say, an image to a pdf, it already has media_attributes
-        and the loop would otherwise not wait for the pdf to be processed
-        """
-        prev_media_attrs = data_row.media_attributes if compare_with_prev_media_attrs else None
-        data_row_id = data_row.uid
-        timeout_seconds = 60
-        while True:
-            data_row = client.get_data_row(data_row_id)
-            if data_row.media_attributes and (prev_media_attrs is None or
-                                              prev_media_attrs
-                                              != data_row.media_attributes):
-                return data_row
-            timeout_seconds -= 2
-            if timeout_seconds <= 0:
-                raise TimeoutError(
-                    f"Timed out waiting for DataRow '{data_row_id}' to finish processing media_attributes"
-                )
-            time.sleep(2)
-
-    return func
-
-
-@pytest.fixture
-def wait_for_label_processing():
-    """
-    Do not use. Only for testing.
-
-    Returns project's labels as a list after waiting for them to finish processing.
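-    Polls for up to ~10 seconds before raising a TimeoutError.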
- If `project.labels()` is called before label is fully processed, - it may return an empty set - """ - - def func(project): - timeout_seconds = 10 - while True: - labels = list(project.labels()) - if len(labels) > 0: - return labels - timeout_seconds -= 2 - if timeout_seconds <= 0: - raise TimeoutError( - f"Timed out waiting for label for project '{project.uid}' to finish processing" - ) - time.sleep(2) - - return func - - -@pytest.fixture -def ontology(client): - ontology_builder = OntologyBuilder( - tools=[ - Tool(tool=Tool.Type.BBOX, name="Box 1", color="#ff0000"), - Tool(tool=Tool.Type.BBOX, name="Box 2", color="#ff0000") - ], - classifications=[ - Classification(name="Root Class", - class_type=Classification.Type.RADIO, - options=[ - Option(value="1", label="Option 1"), - Option(value="2", label="Option 2") - ]) - ]) - ontology = client.create_ontology('Integration Test Ontology', - ontology_builder.asdict(), - MediaType.Image) - yield ontology - client.delete_unused_ontology(ontology.uid) - - -@pytest.fixture -def video_data(client, rand_gen, video_data_row, wait_for_data_row_processing): - dataset = client.create_dataset(name=rand_gen(str)) - data_row_ids = [] - data_row = dataset.create_data_row(video_data_row) - data_row = wait_for_data_row_processing(client, data_row) - data_row_ids.append(data_row.uid) - yield dataset, data_row_ids - dataset.delete() - - -def create_video_data_row(rand_gen): - return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/video-sample-data/sample-video-1.mp4-{rand_gen(str)}", - "media_type": - "VIDEO", - } - - -@pytest.fixture -def video_data_100_rows(client, rand_gen, wait_for_data_row_processing): - dataset = client.create_dataset(name=rand_gen(str)) - data_row_ids = [] - for _ in range(100): - data_row = dataset.create_data_row(create_video_data_row(rand_gen)) - data_row = wait_for_data_row_processing(client, data_row) - data_row_ids.append(data_row.uid) - yield dataset, data_row_ids - dataset.delete() - - -@pytest.fixture() -def video_data_row(rand_gen): - return create_video_data_row(rand_gen) - - -class ExportV2Helpers: - - @classmethod - def run_project_export_v2_task(cls, - project, - num_retries=5, - task_name=None, - filters={}, - params={}): - task = None - params = params if params else { - "project_details": True, - "performance_details": False, - "data_row_details": True, - "label_details": True - } - while (num_retries > 0): - task = project.export_v2(task_name=task_name, - filters=filters, - params=params) - task.wait_till_done() - assert task.status == "COMPLETE" - assert task.errors is None - if len(task.result) == 0: - num_retries -= 1 - time.sleep(5) - else: - break - return task.result - - @classmethod - def run_dataset_export_v2_task(cls, - dataset, - num_retries=5, - task_name=None, - filters={}, - params={}): - task = None - params = params if params else { - "performance_details": False, - "label_details": True - } - while (num_retries > 0): - task = dataset.export_v2(task_name=task_name, - filters=filters, - params=params) - task.wait_till_done() - assert task.status == "COMPLETE" - assert task.errors is None - if len(task.result) == 0: - num_retries -= 1 - time.sleep(5) - else: - break - - return task.result - - @classmethod - def run_catalog_export_v2_task(cls, - client, - num_retries=5, - task_name=None, - filters={}, - params={}): - task = None - params = params if params else { - 
"performance_details": False, - "label_details": True - } - catalog = client.get_catalog() - while (num_retries > 0): - - task = catalog.export_v2(task_name=task_name, - filters=filters, - params=params) - task.wait_till_done() - assert task.status == "COMPLETE" - assert task.errors is None - if len(task.result) == 0: - num_retries -= 1 - time.sleep(5) - else: - break - - return task.result - - -@pytest.fixture -def export_v2_test_helpers() -> Type[ExportV2Helpers]: - return ExportV2Helpers() - - -IMAGE_URL = "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000000034.jpg" -EXTERNAL_ID = "my-image" - - -@pytest.fixture -def big_dataset(dataset: Dataset): - task = dataset.create_data_rows([ - { - "row_data": IMAGE_URL, - "external_id": EXTERNAL_ID - }, - ] * 3) - task.wait_till_done() - - yield dataset - - -@pytest.fixture -def big_dataset_data_row_ids(big_dataset: Dataset) -> List[str]: - yield [dr.uid for dr in list(big_dataset.export_data_rows())] - - -@pytest.fixture(scope='function') -def dataset_with_invalid_data_rows(unique_dataset: Dataset): - upload_invalid_data_rows_for_dataset(unique_dataset) - - yield unique_dataset - - -def upload_invalid_data_rows_for_dataset(dataset: Dataset): - task = dataset.create_data_rows([ - { - "row_data": 'gs://invalid-bucket/example.png', # forbidden - "external_id": "image-without-access.jpg" - }, - ] * 2) - task.wait_till_done() - - -def pytest_configure(): - pytest.report = defaultdict(int) - - -@pytest.hookimpl(hookwrapper=True) -def pytest_fixture_setup(fixturedef): - start = time.time() - yield - end = time.time() - - exec_time = end - start - if "FIXTURE_PROFILE" in os.environ: - pytest.report[fixturedef.argname] += exec_time - - -@pytest.fixture(scope='session', autouse=True) -def print_perf_summary(): - yield - - if "FIXTURE_PROFILE" in os.environ: - sorted_dict = dict( - sorted(pytest.report.items(), - key=lambda item: item[1], - reverse=True)) - num_of_entries = 10 if len(sorted_dict) >= 10 else len(sorted_dict) - slowest_fixtures = [(aaa, sorted_dict[aaa]) - for aaa in islice(sorted_dict, num_of_entries)] - print("\nTop slowest fixtures:\n", slowest_fixtures, file=sys.stderr) - ----- -tests/integration/test_webhook.py -import pytest - -from labelbox import Webhook - - -def test_webhook_create_update(project, rand_gen): - client = project.client - url = "https:/" + rand_gen(str) - secret = rand_gen(str) - topics = [Webhook.LABEL_CREATED, Webhook.LABEL_DELETED] - webhook = Webhook.create(client, topics, url, secret, project) - - assert webhook.project() == project - assert webhook.organization() == client.get_organization() - assert webhook.url == url - assert webhook.topics == topics - assert webhook.status == Webhook.ACTIVE - assert list(project.webhooks()) == [webhook] - assert webhook in set(client.get_organization().webhooks()) - - webhook.update(status=Webhook.REVOKED, topics=[Webhook.LABEL_UPDATED]) - assert webhook.topics == [Webhook.LABEL_UPDATED] - assert webhook.status == Webhook.REVOKED - - with pytest.raises(ValueError) as exc_info: - webhook.update(status="invalid..") - valid_webhook_statuses = {item.value for item in Webhook.Status} - assert str(exc_info.value) == \ - f"Value `invalid..` does not exist in supported values. 
Expected one of {valid_webhook_statuses}" - - with pytest.raises(ValueError) as exc_info: - webhook.update(topics=["invalid.."]) - valid_webhook_topics = {item.value for item in Webhook.Topic} - assert str(exc_info.value) == \ - f"Value `invalid..` does not exist in supported values. Expected one of {valid_webhook_topics}" - - with pytest.raises(TypeError) as exc_info: - webhook.update(topics="invalid..") - assert str(exc_info.value) == \ - "Topics must be List[Webhook.Topic]. Found `invalid..`" - - webhook.delete() - - -def test_webhook_create_with_no_secret(project, rand_gen): - client = project.client - secret = "" - url = "https:/" + rand_gen(str) - topics = [] - - with pytest.raises(ValueError) as exc_info: - Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "Secret must be a non-empty string." - - -def test_webhook_create_with_no_topics(project, rand_gen): - client = project.client - secret = rand_gen(str) - url = "https:/" + rand_gen(str) - topics = [] - - with pytest.raises(ValueError) as exc_info: - Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "Topics must be a non-empty list." - - -def test_webhook_create_with_no_url(project, rand_gen): - client = project.client - secret = rand_gen(str) - url = "" - topics = [Webhook.LABEL_CREATED, Webhook.LABEL_DELETED] - - with pytest.raises(ValueError) as exc_info: - Webhook.create(client, topics, url, secret, project) - assert str(exc_info.value) == \ - "URL must be a non-empty string." - ----- -tests/integration/test_label.py -import time - -import pytest -import requests -import os - -from labelbox import Label - - -def test_labels(configured_project_with_label): - project, _, data_row, label = configured_project_with_label - - assert list(project.labels()) == [label] - assert list(data_row.labels()) == [label] - - assert label.project() == project - assert label.data_row() == data_row - assert label.created_by() == label.client.get_user() - - label.delete() - - # TODO: Added sleep to account for ES from catching up to deletion. 
- # Need a better way to query labels in `project.labels()`, because currently, - # it intermittently takes too long to sync, causing flaky SDK tests - time.sleep(5) - - assert list(project.labels()) == [] - assert list(data_row.labels()) == [] - - -# TODO: Skipping this test in staging due to label not updating -@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem" or - os.environ['LABELBOX_TEST_ENVIRON'] == "staging" or - os.environ['LABELBOX_TEST_ENVIRON'] == "local" or - os.environ['LABELBOX_TEST_ENVIRON'] == "custom", - reason="does not work for onprem") -def test_label_update(configured_project_with_label): - _, _, _, label = configured_project_with_label - label.update(label="something else") - assert label.label == "something else" - - -def test_label_filter_order(configured_project_with_label): - project, _, _, label = configured_project_with_label - - l1 = label - project.create_label() - l2 = next(project.labels()) - - assert set(project.labels()) == {l1, l2} - - assert list(project.labels(order_by=Label.created_at.asc)) == [l1, l2] - assert list(project.labels(order_by=Label.created_at.desc)) == [l2, l1] - - -def test_label_bulk_deletion(configured_project_with_label): - project, _, _, _ = configured_project_with_label - - for _ in range(2): - #only run twice, already have one label in the fixture - project.create_label() - labels = project.labels() - l1 = next(labels) - l2 = next(labels) - l3 = next(labels) - - assert set(project.labels()) == {l1, l2, l3} - - Label.bulk_delete([l1, l3]) - - # TODO: the sdk client should really abstract all these timing issues away - # but for now bulk deletes take enough time that this test is flaky - # add sleep here to avoid that flake - time.sleep(5) - - assert set(project.labels()) == {l2} - ----- -tests/integration/test_task.py -import json -import pytest -import collections.abc -from labelbox import DataRow -from labelbox.schema.data_row_metadata import DataRowMetadataField -from utils import INTEGRATION_SNAPSHOT_DIRECTORY - -TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh" - - -def test_task_errors(dataset, image_url, snapshot): - client = dataset.client - task = dataset.create_data_rows([ - { - DataRow.row_data: - image_url, - DataRow.metadata_fields: [ - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, - value='some msg'), - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, - value='some msg 2') - ] - }, - ]) - - assert task in client.get_user().created_tasks() - task.wait_till_done() - - assert len(task.failed_data_rows) == 1 - assert "A schemaId can only be specified once per DataRow : [cko8s9r5v0001h2dk9elqdidh]" in task.failed_data_rows[ - 0]['message'] - assert len(task.failed_data_rows[0]['failedDataRows'][0]['metadata']) == 2 - - -def test_task_success_json(dataset, image_url, snapshot): - client = dataset.client - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url, - }, - ]) - assert task in client.get_user().created_tasks() - task.wait_till_done() - assert task.status == "COMPLETE" - assert task.errors is None - assert task.result is not None - assert isinstance(task.result, collections.abc.Sequence) - assert task.result_url is not None - assert isinstance(task.result_url, str) - task_result = task.result[0] - assert 'id' in task_result and isinstance(task_result['id'], str) - assert 'row_data' in task_result and isinstance(task_result['row_data'], - str) - snapshot.snapshot_dir = INTEGRATION_SNAPSHOT_DIRECTORY - task_result['id'] = 'DUMMY_ID' - task_result['row_data'] = 'https://dummy.url' - 
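-    # The volatile fields were normalized above so the snapshot comparison is deterministic.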
snapshot.assert_match(json.dumps(task_result), - 'test_task.test_task_success_json.json') - assert len(task.result) - - -def test_task_success_label_export(client, configured_project_with_label): - project, _, _, _ = configured_project_with_label - # TODO: Move to export_v2 - project.export_labels() - user = client.get_user() - task = None - for task in user.created_tasks(): - if task.name != 'JSON Import': - break - - with pytest.raises(ValueError) as exc_info: - task.result - assert str(exc_info.value).startswith("Task result is only supported for") - ----- -tests/integration/test_send_to_annotate.py -from labelbox import UniqueIds, OntologyBuilder, LabelingFrontend -from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy - - -def test_send_to_annotate_include_annotations( - client, configured_batch_project_with_label, project_pack, ontology): - [source_project, _, data_row, _] = configured_batch_project_with_label - destination_project = project_pack[0] - - src_ontology = source_project.ontology() - destination_project.setup_editor(ontology) - - # build an ontology mapping using the top level tools - src_feature_schema_ids = list( - tool.feature_schema_id for tool in src_ontology.tools()) - dest_ontology = destination_project.ontology() - dest_feature_schema_ids = list( - tool.feature_schema_id for tool in dest_ontology.tools()) - # create a dictionary of feature schema id to itself - ontology_mapping = dict(zip(src_feature_schema_ids, - dest_feature_schema_ids)) - - try: - queues = destination_project.task_queues() - initial_review_task = next( - q for q in queues if q.name == "Initial review task") - - # Send the data row to the new project - task = client.send_to_annotate_from_catalog( - destination_project_id=destination_project.uid, - task_queue_id=initial_review_task.uid, - batch_name="test-batch", - data_rows=UniqueIds([data_row.uid]), - params={ - "source_project_id": - source_project.uid, - "annotations_ontology_mapping": - ontology_mapping, - "override_existing_annotations_rule": - ConflictResolutionStrategy.OverrideWithAnnotations - }) - - task.wait_till_done() - - # Check that the data row was sent to the new project - destination_batches = list(destination_project.batches()) - assert len(destination_batches) == 1 - - destination_data_rows = list(destination_batches[0].export_data_rows()) - assert len(destination_data_rows) == 1 - assert destination_data_rows[0].uid == data_row.uid - - # Verify annotations were copied into the destination project - destination_project_labels = (list(destination_project.labels())) - assert len(destination_project_labels) == 1 - finally: - destination_project.delete() - ----- -tests/integration/test_labeling_parameter_overrides.py -import pytest -from labelbox import DataRow -from labelbox.schema.identifiable import GlobalKey, UniqueId -from labelbox.schema.identifiables import GlobalKeys, UniqueIds - - -def test_labeling_parameter_overrides(consensus_project_with_batch): - [project, _, data_rows] = consensus_project_with_batch - - init_labeling_parameter_overrides = list( - project.labeling_parameter_overrides()) - assert len(init_labeling_parameter_overrides) == 3 - assert {o.number_of_labels for o in init_labeling_parameter_overrides - } == {1, 1, 1} - assert {o.priority for o in init_labeling_parameter_overrides} == {5, 5, 5} - assert {o.data_row().uid for o in init_labeling_parameter_overrides - } == {data_rows[0].uid, data_rows[1].uid, data_rows[2].uid} - - data = [(data_rows[0], 4, 2), (data_rows[1], 3)] - 
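-    # Each override tuple is (data row, priority[, optional third element]);
-    # the third element may be omitted, as the second tuple shows.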
success = project.set_labeling_parameter_overrides(data) - assert success - - updated_overrides = list(project.labeling_parameter_overrides()) - assert len(updated_overrides) == 3 - assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1} - assert {o.priority for o in updated_overrides} == {4, 3, 5} - - for override in updated_overrides: - assert isinstance(override.data_row(), DataRow) - - data = [(UniqueId(data_rows[0].uid), 1, 2), (UniqueId(data_rows[1].uid), 2), - (UniqueId(data_rows[2].uid), 3)] - success = project.set_labeling_parameter_overrides(data) - assert success - updated_overrides = list(project.labeling_parameter_overrides()) - assert len(updated_overrides) == 3 - assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1} - assert {o.priority for o in updated_overrides} == {1, 2, 3} - - data = [(GlobalKey(data_rows[0].global_key), 2, 2), - (GlobalKey(data_rows[1].global_key), 3, 3), - (GlobalKey(data_rows[2].global_key), 4)] - success = project.set_labeling_parameter_overrides(data) - assert success - updated_overrides = list(project.labeling_parameter_overrides()) - assert len(updated_overrides) == 3 - assert {o.number_of_labels for o in updated_overrides} == {1, 1, 1} - assert {o.priority for o in updated_overrides} == {2, 3, 4} - - with pytest.raises(TypeError) as exc_info: - data = [(data_rows[2], "a_string", 3)] - project.set_labeling_parameter_overrides(data) - assert str(exc_info.value) == \ - f"Priority must be an int. Found for data_row_identifier {data_rows[2].uid}" - - with pytest.raises(TypeError) as exc_info: - data = [(data_rows[2].uid, 1)] - project.set_labeling_parameter_overrides(data) - assert str(exc_info.value) == \ - f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found for data_row_identifier {data_rows[2].uid}" - - -def test_set_labeling_priority(consensus_project_with_batch): - [project, _, data_rows] = consensus_project_with_batch - - init_labeling_parameter_overrides = list( - project.labeling_parameter_overrides()) - assert len(init_labeling_parameter_overrides) == 3 - assert {o.priority for o in init_labeling_parameter_overrides} == {5, 5, 5} - - data = [data_row.uid for data_row in data_rows] - success = project.update_data_row_labeling_priority(data, 1) - lo = list(project.labeling_parameter_overrides()) - assert success - assert len(lo) == 3 - assert {o.priority for o in lo} == {1, 1, 1} - - data = [data_row.uid for data_row in data_rows] - success = project.update_data_row_labeling_priority(UniqueIds(data), 2) - lo = list(project.labeling_parameter_overrides()) - assert success - assert len(lo) == 3 - assert {o.priority for o in lo} == {2, 2, 2} - - data = [data_row.global_key for data_row in data_rows] - success = project.update_data_row_labeling_priority(GlobalKeys(data), 3) - lo = list(project.labeling_parameter_overrides()) - assert success - assert len(lo) == 3 - assert {o.priority for o in lo} == {3, 3, 3} - ----- -tests/integration/test_sorting.py -import pytest - -from labelbox import Project - - -@pytest.mark.xfail(reason="Sorting not supported on top-level fetches") -def test_top_level_sorting(client): - client.get_projects(order_by=Project.name.asc) - ----- -tests/integration/test_dates.py -from datetime import datetime, timedelta, timezone - - -def test_dates(project): - assert isinstance(project.created_at, datetime) - assert isinstance(project.updated_at, datetime) - - project.update(setup_complete=datetime.now()) - assert isinstance(project.setup_complete, datetime) - - -def 
test_utc_conversion(project): - assert isinstance(project.created_at, datetime) - assert project.created_at.tzinfo == timezone.utc - - # Update with a datetime without TZ info. - project.update(setup_complete=datetime.now()) - # Check that the server-side, UTC date is the same as local date - # converted locally to UTC. - diff = project.setup_complete - datetime.now().astimezone(timezone.utc) - assert abs(diff) < timedelta(minutes=1) - - # Update with a datetime with TZ info - tz = timezone(timedelta(hours=6)) # +6 timezone - project.update(setup_complete=datetime.utcnow().replace(tzinfo=tz)) - diff = datetime.utcnow() - project.setup_complete.replace(tzinfo=None) - assert diff > timedelta(hours=5, minutes=58) - ----- -tests/integration/test_user_and_org.py -from labelbox.schema.project import Project - - -def test_user(client): - user = client.get_user() - assert user.uid is not None - assert user.organization() == client.get_organization() - - -def test_organization(client): - organization = client.get_organization() - assert organization.uid is not None - assert client.get_user() in set(organization.users()) - - -def test_user_and_org_projects(client, project): - user = client.get_user() - org = client.get_organization() - user_project = user.projects(where=Project.uid == project.uid) - org_project = org.projects(where=Project.uid == project.uid) - - assert user_project - assert org_project ----- -tests/integration/test_client_errors.py -from multiprocessing.dummy import Pool -import os -import time -import pytest -from google.api_core.exceptions import RetryError - -from labelbox import Project, Dataset, User -import labelbox.client -import labelbox.exceptions - - -def test_missing_api_key(): - key = os.environ.get(labelbox.client._LABELBOX_API_KEY, None) - if key is not None: - del os.environ[labelbox.client._LABELBOX_API_KEY] - - with pytest.raises(labelbox.exceptions.AuthenticationError) as excinfo: - labelbox.client.Client() - - assert excinfo.value.message == "Labelbox API key not provided" - - if key is not None: - os.environ[labelbox.client._LABELBOX_API_KEY] = key - - -def test_bad_key(rand_gen): - bad_key = "BAD_KEY_" + rand_gen(str) - client = labelbox.client.Client(api_key=bad_key) - - with pytest.raises(labelbox.exceptions.AuthenticationError) as excinfo: - client.create_project(name=rand_gen(str)) - - -def test_syntax_error(client): - with pytest.raises(labelbox.exceptions.InvalidQueryError) as excinfo: - client.execute("asda", check_naming=False) - assert excinfo.value.message.startswith("Syntax Error:") - - -def test_semantic_error(client): - with pytest.raises(labelbox.exceptions.InvalidQueryError) as excinfo: - client.execute("query {bbb {id}}", check_naming=False) - assert excinfo.value.message.startswith("Cannot query field \"bbb\"") - - -def test_timeout_error(client, project): - with pytest.raises(RetryError) as excinfo: - query_str = """query getOntology { - project (where: {id: %s}) { - ontology { - normalized - } - } - } """ % (project.uid) - - # Setting connect timeout to 30s, and read timeout to 0.01s - client.execute(query_str, check_naming=False, timeout=(30.0, 0.01)) - - -def test_query_complexity_error(client): - with pytest.raises(labelbox.exceptions.ValidationFailedError) as excinfo: - client.execute("{projects {datasets {dataRows {labels {id}}}}}", - check_naming=False) - assert excinfo.value.message == "Query complexity limit exceeded" - - -def test_resource_not_found_error(client): - with pytest.raises(labelbox.exceptions.ResourceNotFoundError): - 
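-        # An arbitrary string is not a valid project id, so the lookup comes back empty.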
client.get_project("invalid project ID") - - -def test_network_error(client): - client = labelbox.client.Client(api_key=client.api_key, - endpoint="not_a_valid_URL") - - with pytest.raises(labelbox.exceptions.NetworkError) as excinfo: - client.create_project(name="Project name") - - -def test_invalid_attribute_error( - client, - rand_gen, -): - # Creation - with pytest.raises(labelbox.exceptions.InvalidAttributeError) as excinfo: - client.create_project(name="Name", invalid_field="Whatever") - assert excinfo.value.db_object_type == Project - assert excinfo.value.field == "invalid_field" - - # Update - project = client.create_project(name=rand_gen(str)) - with pytest.raises(labelbox.exceptions.InvalidAttributeError) as excinfo: - project.update(invalid_field="Whatever") - assert excinfo.value.db_object_type == Project - assert excinfo.value.field == "invalid_field" - - # Top-level-fetch - with pytest.raises(labelbox.exceptions.InvalidAttributeError) as excinfo: - client.get_projects(where=User.email == "email") - assert excinfo.value.db_object_type == Project - assert excinfo.value.field == {User.email} - - -@pytest.mark.skip("timeouts cause failure before rate limit") -def test_api_limit_error(client): - - def get(arg): - try: - return client.get_user() - except labelbox.exceptions.ApiLimitError as e: - return e - - # Rate limited at 1500 + buffer - n = 1600 - # max of 30 concurrency before the service becomes unavailable - with Pool(30) as pool: - start = time.time() - results = list(pool.imap(get, range(n)), total=n) - elapsed = time.time() - start - - assert elapsed < 60, "Didn't finish fast enough" - assert labelbox.exceptions.ApiLimitError in {type(r) for r in results} - - # Sleep at the end of this test to allow other tests to execute. - time.sleep(60) - ----- -tests/integration/test_delegated_access.py -import os - -import requests -import pytest - -from labelbox import Client - - -@pytest.mark.skip( - reason= - "Google credentials are being updated for this test, disabling till it's all sorted out" -) -@pytest.mark.skipif(not os.environ.get('DA_GCP_LABELBOX_API_KEY'), - reason="DA_GCP_LABELBOX_API_KEY not found") -def test_default_integration(): - """ - This tests assumes the following: - 1. gcp delegated access is configured to work with jtso-gcs-sdk-da-tests - 2. the integration name is gcs sdk test bucket - 3. This integration is the default - - Currently tests against: - Org ID: cl269lvvj78b50zau34s4550z - Email: jtso+gcp_sdk_tests@labelbox.com""" - client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY")) - ds = client.create_dataset(name="new_ds") - dr = ds.create_data_row( - row_data= - "gs://jtso-gcs-sdk-da-tests/nikita-samokhin-D6QS6iv_CTY-unsplash.jpg") - assert requests.get(dr.row_data).status_code == 200 - assert ds.iam_integration().name == "gcs sdk test bucket" - ds.delete() - - -@pytest.mark.skip( - reason= - "Google credentials are being updated for this test, disabling till it's all sorted out" -) -@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"), - reason="DA_GCP_LABELBOX_API_KEY not found") -def test_non_default_integration(): - """ - This tests assumes the following: - 1. aws delegated access is configured to work with lbox-test-bucket - 2. 
an integration called aws is available to the org - - Currently tests against: - Org ID: cl26d06tk0gch10901m7jeg9v - Email: jtso+aws_sdk_tests@labelbox.com - """ - client = Client(api_key=os.environ.get("DA_GCP_LABELBOX_API_KEY")) - integrations = client.get_organization().get_iam_integrations() - integration = [ - inte for inte in integrations if 'aws-da-test-bucket' in inte.name - ][0] - assert integration.valid - ds = client.create_dataset(iam_integration=integration, name="new_ds") - assert ds.iam_integration().name == "aws-da-test-bucket" - dr = ds.create_data_row( - row_data= - "https://jtso-aws-da-sdk-tests.s3.us-east-2.amazonaws.com/adrian-yu-qkN4D3Rf1gw-unsplash.jpg" - ) - assert requests.get(dr.row_data).status_code == 200 - ds.delete() - - -def test_no_integration(client, image_url): - ds = client.create_dataset(iam_integration=None, name="new_ds") - assert ds.iam_integration() is None - dr = ds.create_data_row(row_data=image_url) - assert requests.get(dr.row_data).status_code == 200 - ds.delete() - - -def test_no_default_integration(client): - ds = client.create_dataset(name="new_ds") - assert ds.iam_integration() is None - ds.delete() - ----- -tests/integration/test_data_row_delete_metadata.py -from datetime import datetime -import uuid - -import pytest - -from labelbox import DataRow, Dataset -from labelbox.exceptions import MalformedQueryException -from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadata, DataRowMetadataKind, DeleteDataRowMetadata -from labelbox.schema.identifiable import GlobalKey, UniqueId - -INVALID_SCHEMA_ID = "1" * 25 -FAKE_SCHEMA_ID = "0" * 25 -FAKE_DATAROW_ID = "D" * 25 -SPLIT_SCHEMA_ID = "cko8sbczn0002h2dkdaxb5kal" -TRAIN_SPLIT_ID = "cko8sbscr0003h2dk04w86hof" -TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt" -TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh" -CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb" -CUSTOM_TEXT_SCHEMA_NAME = 'custom_text' - -FAKE_NUMBER_FIELD = { - "id": FAKE_SCHEMA_ID, - "name": "number", - "kind": 'CustomMetadataNumber', - "reserved": False -} - - -@pytest.fixture -def mdo(client): - mdo = client.get_data_row_metadata_ontology() - try: - mdo.create_schema(CUSTOM_TEXT_SCHEMA_NAME, DataRowMetadataKind.string) - except MalformedQueryException: - # Do nothing if already exists - pass - mdo._raw_ontology = mdo._get_ontology() - mdo._raw_ontology.append(FAKE_NUMBER_FIELD) - mdo._build_ontology() - yield mdo - - -@pytest.fixture -def big_dataset(dataset: Dataset, image_url): - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image", - "global_key": str(uuid.uuid4()) - }, - ] * 5) - task.wait_till_done() - - yield dataset - - -def make_metadata(dr_id: str = None, gk: str = None) -> DataRowMetadata: - msg = "A message" - time = datetime.utcnow() - - metadata = DataRowMetadata( - global_key=gk, - data_row_id=dr_id, - fields=[ - DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID, - value=TEST_SPLIT_ID), - DataRowMetadataField(schema_id=CAPTURE_DT_SCHEMA_ID, value=time), - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value=msg), - ]) - return metadata - - -def make_named_metadata(dr_id) -> DataRowMetadata: - msg = "A message" - time = datetime.utcnow() - - metadata = DataRowMetadata(data_row_id=dr_id, - fields=[ - DataRowMetadataField(name='split', - value=TEST_SPLIT_ID), - DataRowMetadataField(name='captureDateTime', - value=time), - DataRowMetadataField( - name=CUSTOM_TEXT_SCHEMA_NAME, value=msg), - ]) - return metadata - - -def test_bulk_delete_datarow_metadata(data_row, mdo): - 
"""test bulk deletes for all fields""" - metadata = make_metadata(data_row.uid) - mdo.bulk_upsert([metadata]) - assert len(mdo.bulk_export([data_row.uid])[0].fields) - upload_ids = [m.schema_id for m in metadata.fields[:-2]] - mdo.bulk_delete( - [DeleteDataRowMetadata(data_row_id=data_row.uid, fields=upload_ids)]) - remaining_ids = set( - [f.schema_id for f in mdo.bulk_export([data_row.uid])[0].fields]) - assert not len(remaining_ids.intersection(set(upload_ids))) - - -@pytest.fixture -def data_row_unique_id(data_row): - return UniqueId(data_row.uid) - - -@pytest.fixture -def data_row_global_key(data_row): - return GlobalKey(data_row.global_key) - - -@pytest.fixture -def data_row_id_as_str(data_row): - return data_row.uid - - -@pytest.mark.parametrize( - 'data_row_for_delete', - ['data_row_id_as_str', 'data_row_unique_id', 'data_row_global_key']) -def test_bulk_delete_datarow_metadata(data_row_for_delete, data_row, mdo, - request): - """test bulk deletes for all fields""" - metadata = make_metadata(data_row.uid) - mdo.bulk_upsert([metadata]) - assert len(mdo.bulk_export([data_row.uid])[0].fields) - upload_ids = [m.schema_id for m in metadata.fields[:-2]] - mdo.bulk_delete([ - DeleteDataRowMetadata( - data_row_id=request.getfixturevalue(data_row_for_delete), - fields=upload_ids) - ]) - remaining_ids = set( - [f.schema_id for f in mdo.bulk_export([data_row.uid])[0].fields]) - assert not len(remaining_ids.intersection(set(upload_ids))) - - -@pytest.mark.parametrize( - 'data_row_for_delete', - ['data_row_id_as_str', 'data_row_unique_id', 'data_row_global_key']) -def test_bulk_partial_delete_datarow_metadata(data_row_for_delete, data_row, - mdo, request): - """Delete a single from metadata""" - n_fields = len(mdo.bulk_export([data_row.uid])[0].fields) - metadata = make_metadata(data_row.uid) - mdo.bulk_upsert([metadata]) - - assert len(mdo.bulk_export( - [data_row.uid])[0].fields) == (n_fields + len(metadata.fields)) - - mdo.bulk_delete([ - DeleteDataRowMetadata( - data_row_id=request.getfixturevalue(data_row_for_delete), - fields=[TEXT_SCHEMA_ID]) - ]) - fields = [f for f in mdo.bulk_export([data_row.uid])[0].fields] - assert len(fields) == (len(metadata.fields) - 1) - - -@pytest.fixture -def data_row_unique_ids(big_dataset): - deletes = [] - data_row_ids = [dr.uid for dr in big_dataset.data_rows()] - - for data_row_id in data_row_ids: - deletes.append( - DeleteDataRowMetadata( - data_row_id=UniqueId(data_row_id), - fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID])) - return deletes - - -@pytest.fixture -def data_row_ids_as_str(big_dataset): - deletes = [] - data_row_ids = [dr.uid for dr in big_dataset.data_rows()] - - for data_row_id in data_row_ids: - deletes.append( - DeleteDataRowMetadata( - data_row_id=data_row_id, - fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID])) - return deletes - - -@pytest.fixture -def data_row_global_keys(big_dataset): - deletes = [] - global_keys = [dr.global_key for dr in big_dataset.data_rows()] - - for data_row_id in global_keys: - deletes.append( - DeleteDataRowMetadata( - data_row_id=GlobalKey(data_row_id), - fields=[SPLIT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID])) - return deletes - - -@pytest.mark.parametrize( - 'data_rows_for_delete', - ['data_row_ids_as_str', 'data_row_unique_ids', 'data_row_global_keys']) -def test_large_bulk_delete_datarow_metadata(data_rows_for_delete, big_dataset, - mdo, request): - metadata = [] - data_row_ids = [dr.uid for dr in big_dataset.data_rows()] - for data_row_id in data_row_ids: - metadata.append( - 
DataRowMetadata(data_row_id=data_row_id,
-                            fields=[
-                                DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID,
-                                                     value=TEST_SPLIT_ID),
-                                DataRowMetadataField(schema_id=TEXT_SCHEMA_ID,
-                                                     value="test-message")
-                            ]))
-    errors = mdo.bulk_upsert(metadata)
-    assert len(errors) == 0
-
-    deletes = request.getfixturevalue(data_rows_for_delete)
-    errors = mdo.bulk_delete(deletes)
-
-    assert len(errors) == len(data_row_ids)
-    for error in errors:
-        assert error.fields == [CAPTURE_DT_SCHEMA_ID]
-        assert error.error == 'Schema did not exist'
-
-    for data_row_id in data_row_ids:
-        fields = [f for f in mdo.bulk_export([data_row_id])[0].fields]
-        assert len(fields) == 1, fields
-        assert SPLIT_SCHEMA_ID not in [field.schema_id for field in fields]
-
-
-@pytest.mark.parametrize(
-    'data_row_for_delete',
-    ['data_row_id_as_str', 'data_row_unique_id', 'data_row_global_key'])
-def test_bulk_delete_datarow_enum_metadata(data_row_for_delete,
-                                           data_row: DataRow, mdo, request):
-    """test bulk deletes for enum fields"""
-    metadata = make_metadata(data_row.uid)
-    metadata.fields = [
-        m for m in metadata.fields if m.schema_id == SPLIT_SCHEMA_ID
-    ]
-    mdo.bulk_upsert([metadata])
-
-    exported = mdo.bulk_export([data_row.uid])[0].fields
-    assert len(exported) == len(
-        set([x.schema_id for x in metadata.fields] +
-            [x.schema_id for x in exported]))
-
-    mdo.bulk_delete([
-        DeleteDataRowMetadata(
-            data_row_id=request.getfixturevalue(data_row_for_delete),
-            fields=[SPLIT_SCHEMA_ID])
-    ])
-    exported = mdo.bulk_export([data_row.uid])[0].fields
-    assert len(exported) == 0
-
-
-@pytest.mark.parametrize(
-    'data_row_for_delete',
-    ['data_row_id_as_str', 'data_row_unique_id', 'data_row_global_key'])
-def test_delete_non_existent_schema_id(data_row_for_delete, data_row, mdo,
-                                       request):
-    res = mdo.bulk_delete([
-        DeleteDataRowMetadata(
-            data_row_id=request.getfixturevalue(data_row_for_delete),
-            fields=[SPLIT_SCHEMA_ID])
-    ])
-    assert len(res) == 1
-    assert res[0].fields == [SPLIT_SCHEMA_ID]
-    assert res[0].error == 'Schema did not exist'
-
-----
-tests/integration/test_labeler_performance.py
-from datetime import datetime, timezone, timedelta
-import pytest
-import os
-
-
-@pytest.mark.skipif(
-    condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem",
-    reason="longer runtime than expected for onprem. 
unskip when resolved.") -def test_labeler_performance(configured_project_with_label): - project, _, _, _ = configured_project_with_label - - labeler_performance = list(project.labeler_performance()) - assert len(labeler_performance) == 1 - my_performance = labeler_performance[0] - assert my_performance.user == project.client.get_user() - assert my_performance.count == 1 - assert isinstance(my_performance.last_activity_time, datetime) - now_utc = datetime.now().astimezone(timezone.utc) - assert timedelta(0) < now_utc - my_performance.last_activity_time < \ - timedelta(seconds=60) - ----- -tests/integration/test_project_setup.py -from datetime import datetime, timedelta, timezone -import json -import time - -import pytest - -from labelbox import LabelingFrontend -from labelbox.exceptions import InvalidQueryError, ResourceConflict - - -def simple_ontology(): - classifications = [{ - "name": "test_ontology", - "instructions": "Which class is this?", - "type": "radio", - "options": [{ - "value": c, - "label": c - } for c in ["one", "two", "three"]], - "required": True, - }] - - return {"tools": [], "classifications": classifications} - - -def test_project_setup(project) -> None: - client = project.client - labeling_frontends = list( - client.get_labeling_frontends(where=LabelingFrontend.name == 'Editor')) - assert len(labeling_frontends) - labeling_frontend = labeling_frontends[0] - - time.sleep(3) - now = datetime.now().astimezone(timezone.utc) - - project.setup(labeling_frontend, simple_ontology()) - assert now - project.setup_complete <= timedelta(seconds=3) - assert now - project.last_activity_time <= timedelta(seconds=3) - - assert project.labeling_frontend() == labeling_frontend - options = list(project.labeling_frontend_options()) - assert len(options) == 1 - options = options[0] - # TODO ensure that LabelingFrontendOptions can be obtaind by ID - with pytest.raises(InvalidQueryError): - assert options.labeling_frontend() == labeling_frontend - assert options.project() == project - assert options.organization() == client.get_organization() - assert options.customization_options == json.dumps(simple_ontology()) - assert project.organization() == client.get_organization() - assert project.created_by() == client.get_user() - - -def test_project_editor_setup(client, project, rand_gen): - ontology_name = f"test_project_editor_setup_ontology_name-{rand_gen(str)}" - ontology = client.create_ontology(ontology_name, simple_ontology()) - now = datetime.now().astimezone(timezone.utc) - project.setup_editor(ontology) - assert now - project.setup_complete <= timedelta(seconds=3) - assert now - project.last_activity_time <= timedelta(seconds=3) - assert project.labeling_frontend().name == "Editor" - assert project.organization() == client.get_organization() - assert project.created_by() == client.get_user() - assert project.ontology().name == ontology_name - # Make sure that setup only creates one ontology - time.sleep(3) # Search takes a second - assert [ontology.name for ontology in client.get_ontologies(ontology_name) - ] == [ontology_name] - - -def test_project_editor_setup_cant_call_multiple_times(client, project, - rand_gen): - ontology_name = f"test_project_editor_setup_ontology_name-{rand_gen(str)}" - ontology = client.create_ontology(ontology_name, simple_ontology()) - project.setup_editor(ontology) - with pytest.raises(ResourceConflict): - project.setup_editor(ontology) - ----- -tests/integration/test_data_rows.py -from tempfile import NamedTemporaryFile -import time -import uuid -from 
datetime import datetime
-import json
-from labelbox.schema.media_type import MediaType
-
-import pytest
-import requests
-
-from labelbox import DataRow
-from labelbox.exceptions import MalformedQueryException
-from labelbox.schema.task import Task
-from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadataKind
-import labelbox.exceptions
-
-SPLIT_SCHEMA_ID = "cko8sbczn0002h2dkdaxb5kal"
-TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt"
-TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh"
-CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb"
-EXPECTED_METADATA_SCHEMA_IDS = {
-    SPLIT_SCHEMA_ID, TEST_SPLIT_ID, TEXT_SCHEMA_ID, CAPTURE_DT_SCHEMA_ID
-}
-CUSTOM_TEXT_SCHEMA_NAME = "custom_text"
-
-
-@pytest.fixture
-def mdo(client):
-    mdo = client.get_data_row_metadata_ontology()
-    try:
-        mdo.create_schema(CUSTOM_TEXT_SCHEMA_NAME, DataRowMetadataKind.string)
-    except MalformedQueryException:
-        # Do nothing if already exists
-        pass
-    mdo._raw_ontology = mdo._get_ontology()
-    mdo._build_ontology()
-    yield mdo
-
-
-@pytest.fixture
-def conversational_content():
-    return {
-        'row_data': {
-            "messages": [{
-                "messageId": "message-0",
-                "timestampUsec": 1530718491,
-                "content": "I love iphone! i just bought new iphone! 🥰 📲",
-                "user": {
-                    "userId": "Bot 002",
-                    "name": "Bot"
-                },
-                "align": "left",
-                "canLabel": False
-            }],
-            "version": 1,
-            "type": "application/vnd.labelbox.conversational"
-        }
-    }
-
-
-@pytest.fixture
-def tile_content():
-    return {
-        "row_data": {
-            "tileLayerUrl":
-                "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png",
-            "bounds": [[19.405662413477728, -99.21052827588443],
-                       [19.400498983095076, -99.20534818927473]],
-            "minZoom": 12,
-            "maxZoom": 20,
-            "epsg": "EPSG4326",
-            "alternativeLayers": [{
-                "tileLayerUrl":
-                    "https://api.mapbox.com/styles/v1/mapbox/satellite-streets-v11/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw",
-                "name": "Satellite"
-            }, {
-                "tileLayerUrl":
-                    "https://api.mapbox.com/styles/v1/mapbox/navigation-guidance-night-v4/tiles/{z}/{x}/{y}?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXVycTA2emYycXBndHRqcmZ3N3gifQ.rJcFIG214AriISLbB6B5aw",
-                "name": "Guidance"
-            }]
-        }
-    }
-
-
-def make_metadata_fields():
-    msg = "A message"
-    time = datetime.utcnow()
-
-    fields = [
-        DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID, value=TEST_SPLIT_ID),
-        DataRowMetadataField(schema_id=CAPTURE_DT_SCHEMA_ID, value=time),
-        DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value=msg),
-    ]
-    return fields
-
-
-def make_metadata_fields_dict():
-    msg = "A message"
-    time = datetime.utcnow()
-
-    fields = [{
-        "schema_id": SPLIT_SCHEMA_ID,
-        "value": TEST_SPLIT_ID
-    }, {
-        "schema_id": CAPTURE_DT_SCHEMA_ID,
-        "value": time
-    }, {
-        "schema_id": TEXT_SCHEMA_ID,
-        "value": msg
-    }]
-    return fields
-
-
-def test_get_data_row_by_global_key(data_row_and_global_key, client, rand_gen):
-    _, global_key = data_row_and_global_key
-    data_row = client.get_data_row_by_global_key(global_key)
-    assert type(data_row) == DataRow
-    assert data_row.global_key == global_key
-
-
-def test_get_data_row(data_row, client):
-    assert client.get_data_row(data_row.uid)
-
-
-def test_create_invalid_aws_data_row(dataset, client):
-    with pytest.raises(labelbox.exceptions.InvalidQueryError) as exc:
-        dataset.create_data_row(row_data="s3://labelbox-public-data/invalid")
-    assert "s3" in exc.value.message
-
-    with pytest.raises(labelbox.exceptions.InvalidQueryError) as exc:
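-        # Bulk creation should reject an unreachable s3:// URI just like the single-row call above.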
dataset.create_data_rows([{ - "row_data": "s3://labelbox-public-data/invalid" - }]) - assert "s3" in exc.value.message - - -def test_lookup_data_rows(client, dataset): - uid = str(uuid.uuid4()) - # 1 external id : 1 uid - dr = dataset.create_data_row(row_data="123", external_id=uid) - lookup = client.get_data_row_ids_for_external_ids([uid]) - assert len(lookup) == 1 - assert lookup[uid][0] == dr.uid - # 2 external ids : 1 uid - uid2 = str(uuid.uuid4()) - dr2 = dataset.create_data_row(row_data="123", external_id=uid2) - lookup = client.get_data_row_ids_for_external_ids([uid, uid2]) - assert len(lookup) == 2 - assert all([len(x) == 1 for x in lookup.values()]) - assert lookup[uid][0] == dr.uid - assert lookup[uid2][0] == dr2.uid - # 1 external id : 2 uid - dr3 = dataset.create_data_row(row_data="123", external_id=uid2) - lookup = client.get_data_row_ids_for_external_ids([uid2]) - assert len(lookup) == 1 - assert len(lookup[uid2]) == 2 - assert lookup[uid2][0] == dr2.uid - assert lookup[uid2][1] == dr3.uid - # Empty args - lookup = client.get_data_row_ids_for_external_ids([]) - assert len(lookup) == 0 - # Non matching - lookup = client.get_data_row_ids_for_external_ids([str(uuid.uuid4())]) - assert len(lookup) == 0 - - -def test_data_row_bulk_creation(dataset, rand_gen, image_url): - client = dataset.client - assert len(list(dataset.data_rows())) == 0 - - # Test creation using URL - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url - }, - { - "row_data": image_url - }, - ]) - assert task in client.get_user().created_tasks() - task.wait_till_done() - assert task.status == "COMPLETE" - - data_rows = list(dataset.data_rows()) - assert len(data_rows) == 2 - assert {data_row.row_data for data_row in data_rows} == {image_url} - assert {data_row.global_key for data_row in data_rows} == {None} - - data_rows = list(dataset.data_rows(from_cursor=data_rows[0].uid)) - assert len(data_rows) == 1 - - # Test creation using file name - with NamedTemporaryFile() as fp: - data = rand_gen(str).encode() - fp.write(data) - fp.flush() - task = dataset.create_data_rows([fp.name]) - task.wait_till_done() - assert task.status == "COMPLETE" - - task = dataset.create_data_rows([{ - "row_data": fp.name, - 'external_id': 'some_name' - }]) - task.wait_till_done() - assert task.status == "COMPLETE" - - task = dataset.create_data_rows([{"row_data": fp.name}]) - task.wait_till_done() - assert task.status == "COMPLETE" - - data_rows = list(dataset.data_rows()) - assert len(data_rows) == 5 - url = ({data_row.row_data for data_row in data_rows} - {image_url}).pop() - assert requests.get(url).content == data - - for dr in data_rows: - dr.delete() - - -@pytest.mark.slow -def test_data_row_large_bulk_creation(dataset, image_url): - # Do a longer task and expect it not to be complete immediately - n_urls = 1000 - n_local = 250 - with NamedTemporaryFile() as fp: - fp.write("Test data".encode()) - fp.flush() - task = dataset.create_data_rows([{ - DataRow.row_data: image_url - }] * n_urls + [fp.name] * n_local) - task.wait_till_done() - assert task.status == "COMPLETE" - assert len(list(dataset.data_rows())) == n_local + n_urls - - -def test_data_row_single_creation(dataset, rand_gen, image_url): - client = dataset.client - assert len(list(dataset.data_rows())) == 0 - - data_row = dataset.create_data_row(row_data=image_url) - assert len(list(dataset.data_rows())) == 1 - assert data_row.dataset() == dataset - assert data_row.created_by() == client.get_user() - assert data_row.organization() == 
client.get_organization()
-    assert requests.get(image_url).content == \
-        requests.get(data_row.row_data).content
-    assert data_row.media_attributes is not None
-    assert data_row.global_key is None
-
-    with NamedTemporaryFile() as fp:
-        data = rand_gen(str).encode()
-        fp.write(data)
-        fp.flush()
-        data_row_2 = dataset.create_data_row(row_data=fp.name)
-        assert len(list(dataset.data_rows())) == 2
-        assert requests.get(data_row_2.row_data).content == data
-
-
-def test_create_data_row_with_dict(dataset, image_url):
-    client = dataset.client
-    assert len(list(dataset.data_rows())) == 0
-    dr = {"row_data": image_url}
-    data_row = dataset.create_data_row(dr)
-    assert len(list(dataset.data_rows())) == 1
-    assert data_row.dataset() == dataset
-    assert data_row.created_by() == client.get_user()
-    assert data_row.organization() == client.get_organization()
-    assert requests.get(image_url).content == \
-        requests.get(data_row.row_data).content
-    assert data_row.media_attributes is not None
-
-
-def test_create_data_row_with_dict_containing_field(dataset, image_url):
-    client = dataset.client
-    assert len(list(dataset.data_rows())) == 0
-    dr = {DataRow.row_data: image_url}
-    data_row = dataset.create_data_row(dr)
-    assert len(list(dataset.data_rows())) == 1
-    assert data_row.dataset() == dataset
-    assert data_row.created_by() == client.get_user()
-    assert data_row.organization() == client.get_organization()
-    assert requests.get(image_url).content == \
-        requests.get(data_row.row_data).content
-    assert data_row.media_attributes is not None
-
-
-def test_create_data_row_with_dict_unpacked(dataset, image_url):
-    client = dataset.client
-    assert len(list(dataset.data_rows())) == 0
-    dr = {"row_data": image_url}
-    data_row = dataset.create_data_row(**dr)
-    assert len(list(dataset.data_rows())) == 1
-    assert data_row.dataset() == dataset
-    assert data_row.created_by() == client.get_user()
-    assert data_row.organization() == client.get_organization()
-    assert requests.get(image_url).content == \
-        requests.get(data_row.row_data).content
-    assert data_row.media_attributes is not None
-
-
-def test_create_data_row_with_invalid_input(dataset, image_url):
-    with pytest.raises(labelbox.exceptions.InvalidQueryError) as exc:
-        dataset.create_data_row("asdf")
-
-    dr = {"row_data": image_url}
-    with pytest.raises(labelbox.exceptions.InvalidQueryError) as exc:
-        dataset.create_data_row(dr, row_data=image_url)
-
-
-def test_create_data_row_with_metadata(mdo, dataset, image_url):
-    client = dataset.client
-    assert len(list(dataset.data_rows())) == 0
-
-    data_row = dataset.create_data_row(row_data=image_url,
-                                       metadata_fields=make_metadata_fields())
-
-    assert len(list(dataset.data_rows())) == 1
-    assert data_row.dataset() == dataset
-    assert data_row.created_by() == client.get_user()
-    assert data_row.organization() == client.get_organization()
-    assert requests.get(image_url).content == \
-        requests.get(data_row.row_data).content
-    assert data_row.media_attributes is not None
-    metadata_fields = data_row.metadata_fields
-    metadata = data_row.metadata
-    assert len(metadata_fields) == 3
-    assert len(metadata) == 3
-    # list.sort() returns None, so the comparison must use sorted()
-    assert sorted([m["schemaId"] for m in metadata_fields
-                  ]) == EXPECTED_METADATA_SCHEMA_IDS
-    for m in metadata:
-        assert mdo._parse_upsert(m)
-
-
-def test_create_data_row_with_metadata_dict(mdo, dataset, image_url):
-    client = dataset.client
-    assert len(list(dataset.data_rows())) == 0
-
-    data_row = dataset.create_data_row(
-        row_data=image_url, metadata_fields=make_metadata_fields_dict())
-
-    assert len(list(dataset.data_rows())) == 1
-    assert data_row.dataset() == dataset
-    assert data_row.created_by() == client.get_user()
-    assert data_row.organization() == client.get_organization()
-    assert requests.get(image_url).content == \
-        requests.get(data_row.row_data).content
-    assert data_row.media_attributes is not None
-    metadata_fields = data_row.metadata_fields
-    metadata = data_row.metadata
-    assert len(metadata_fields) == 3
-    assert len(metadata) == 3
-    assert sorted([m["schemaId"] for m in metadata_fields
-                  ]) == EXPECTED_METADATA_SCHEMA_IDS
-    for m in metadata:
-        assert mdo._parse_upsert(m)
-
-
-def test_create_data_row_with_invalid_metadata(dataset, image_url):
-    fields = make_metadata_fields()
-    # make the payload invalid by providing the same schema id more than once
-    fields.append(
-        DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value='some msg'))
-
-    with pytest.raises(labelbox.exceptions.MalformedQueryException):
-        dataset.create_data_row(row_data=image_url, metadata_fields=fields)
-
-
-def test_create_data_rows_with_metadata(mdo, dataset, image_url):
-    client = dataset.client
-    assert len(list(dataset.data_rows())) == 0
-
-    task = dataset.create_data_rows([
-        {
-            DataRow.row_data: image_url,
-            DataRow.external_id: "row1",
-            DataRow.metadata_fields: make_metadata_fields()
-        },
-        {
-            DataRow.row_data: image_url,
-            DataRow.external_id: "row2",
-            "metadata_fields": make_metadata_fields()
-        },
-        {
-            DataRow.row_data: image_url,
-            DataRow.external_id: "row3",
-            DataRow.metadata_fields: make_metadata_fields_dict()
-        },
-        {
-            DataRow.row_data: image_url,
-            DataRow.external_id: "row4",
-            "metadata_fields": make_metadata_fields_dict()
-        },
-    ])
-    task.wait_till_done()
-
-    assert len(list(dataset.data_rows())) == 4
-    for r in ["row1", "row2", "row3", "row4"]:
-        row = list(dataset.data_rows(where=DataRow.external_id == r))[0]
-        assert row.dataset() == dataset
-        assert row.created_by() == client.get_user()
-        assert row.organization() == client.get_organization()
-        assert requests.get(image_url).content == \
-            requests.get(row.row_data).content
-        assert row.media_attributes is not None
-
-        metadata_fields = row.metadata_fields
-        metadata = row.metadata
-        assert len(metadata_fields) == 3
-        assert len(metadata) == 3
-        assert sorted([m["schemaId"] for m in metadata_fields
-                      ]) == EXPECTED_METADATA_SCHEMA_IDS
-        for m in metadata:
-            assert mdo._parse_upsert(m)
-
-
-@pytest.mark.parametrize("test_function,metadata_obj_type",
-                         [("create_data_rows", "class"),
-                          ("create_data_rows", "dict"),
-                          ("create_data_rows_sync", "class"),
-                          ("create_data_rows_sync", "dict"),
-                          ("create_data_row", "class"),
-                          ("create_data_row", "dict")])
-def test_create_data_rows_with_named_metadata_field_class(
-        test_function, metadata_obj_type, mdo, dataset, image_url):
-
-    row_with_metadata_field = {
-        DataRow.row_data: image_url,
-        DataRow.external_id: "row1",
-        DataRow.metadata_fields: [
-            DataRowMetadataField(name='split', value='test'),
-            DataRowMetadataField(name=CUSTOM_TEXT_SCHEMA_NAME, value='hello')
-        ]
-    }
-
-    row_with_metadata_dict = {
-        DataRow.row_data: image_url,
-        DataRow.external_id: "row2",
-        "metadata_fields": [
-            {
-                'name': 'split',
-                'value': 'test'
-            },
-            {
-                'name': CUSTOM_TEXT_SCHEMA_NAME,
-                'value': 'hello'
-            },
-        ]
-    }
-
-    assert len(list(dataset.data_rows())) == 0
-
-    METADATA_FIELDS = {
-        "class": row_with_metadata_field,
-        "dict": row_with_metadata_dict
-    }
-
-    def create_data_row(data_rows):
-        dataset.create_data_row(data_rows[0])
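-
-    # CREATION_FUNCTION maps each parametrized function name to a callable that
-    # accepts a list of row payloads, so a single test body can exercise the
-    # task-based API, the synchronous API, and single-row creation uniformly.
-    CREATION_FUNCTION = {
-        "create_data_rows":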
dataset.create_data_rows, - "create_data_rows_sync": dataset.create_data_rows_sync, - "create_data_row": create_data_row - } - data_rows = [METADATA_FIELDS[metadata_obj_type]] - function_to_test = CREATION_FUNCTION[test_function] - task = function_to_test(data_rows) - - if isinstance(task, Task): - task.wait_till_done() - - created_rows = list(dataset.data_rows()) - assert len(created_rows) == 1 - assert len(created_rows[0].metadata_fields) == 2 - assert len(created_rows[0].metadata) == 2 - - metadata = created_rows[0].metadata - assert metadata[0].schema_id == SPLIT_SCHEMA_ID - assert metadata[0].name == 'test' - assert metadata[0].value == mdo.reserved_by_name['split']['test'].uid - assert metadata[1].name == CUSTOM_TEXT_SCHEMA_NAME - assert metadata[1].value == 'hello' - assert metadata[1].schema_id == mdo.custom_by_name[ - CUSTOM_TEXT_SCHEMA_NAME].uid - - -def test_create_data_rows_with_invalid_metadata(dataset, image_url): - fields = make_metadata_fields() - # make the payload invalid by providing the same schema id more than once - fields.append( - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value="some msg")) - - task = dataset.create_data_rows([{ - DataRow.row_data: image_url, - DataRow.metadata_fields: fields - }]) - task.wait_till_done(timeout_seconds=60) - - assert task.status == "COMPLETE" - assert len(task.failed_data_rows) == 1 - assert f"A schemaId can only be specified once per DataRow : [{TEXT_SCHEMA_ID}]" in task.failed_data_rows[ - 0]["message"] - - -def test_create_data_rows_with_metadata_missing_value(dataset, image_url): - fields = make_metadata_fields() - fields.append({"schemaId": "some schema id"}) - - with pytest.raises(ValueError) as exc: - dataset.create_data_rows([ - { - DataRow.row_data: image_url, - DataRow.external_id: "row1", - DataRow.metadata_fields: fields - }, - ]) - - -def test_create_data_rows_with_metadata_missing_schema_id(dataset, image_url): - fields = make_metadata_fields() - fields.append({"value": "some value"}) - - with pytest.raises(ValueError) as exc: - dataset.create_data_rows([ - { - DataRow.row_data: image_url, - DataRow.external_id: "row1", - DataRow.metadata_fields: fields - }, - ]) - - -def test_create_data_rows_with_metadata_wrong_type(dataset, image_url): - fields = make_metadata_fields() - fields.append("Neither DataRowMetadataField or dict") - - with pytest.raises(ValueError) as exc: - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url, - DataRow.external_id: "row1", - DataRow.metadata_fields: fields - }, - ]) - - -def test_data_row_update_missing_or_empty_required_fields( - dataset, rand_gen, image_url): - external_id = rand_gen(str) - data_row = dataset.create_data_row(row_data=image_url, - external_id=external_id) - with pytest.raises(ValueError): - data_row.update(row_data="") - with pytest.raises(ValueError): - data_row.update(row_data={}) - with pytest.raises(ValueError): - data_row.update(external_id="") - with pytest.raises(ValueError): - data_row.update(global_key="") - with pytest.raises(ValueError): - data_row.update() - - -def test_data_row_update(client, dataset, rand_gen, image_url, - wait_for_data_row_processing): - external_id = rand_gen(str) - data_row = dataset.create_data_row(row_data=image_url, - external_id=external_id) - assert data_row.external_id == external_id - - external_id_2 = rand_gen(str) - data_row.update(external_id=external_id_2) - assert data_row.external_id == external_id_2 - - in_line_content = "123" - data_row.update(row_data=in_line_content) - assert data_row.row_data == 
in_line_content - - data_row.update(row_data=image_url) - data_row = wait_for_data_row_processing(client, data_row) - assert data_row.row_data == image_url - - # tileLayer becomes a media attribute - pdf_url = "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" - tileLayerUrl = "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483-lb-textlayer.json" - data_row.update(row_data={'pdfUrl': pdf_url, "tileLayerUrl": tileLayerUrl}) - data_row = wait_for_data_row_processing(client, - data_row, - compare_with_prev_media_attrs=True) - assert data_row.row_data == pdf_url - - -def test_data_row_filtering_sorting(dataset, image_url): - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url, - DataRow.external_id: "row1" - }, - { - DataRow.row_data: image_url, - DataRow.external_id: "row2" - }, - ]) - task.wait_till_done() - - # Test filtering - row1 = list(dataset.data_rows(where=DataRow.external_id == "row1")) - assert len(row1) == 1 - row1 = dataset.data_rows_for_external_id("row1") - assert len(row1) == 1 - row1 = row1[0] - assert row1.external_id == "row1" - row2 = list(dataset.data_rows(where=DataRow.external_id == "row2")) - assert len(row2) == 1 - row2 = dataset.data_rows_for_external_id("row2") - assert len(row2) == 1 - row2 = row2[0] - assert row2.external_id == "row2" - - -@pytest.fixture -def create_datarows_for_data_row_deletion(dataset, image_url): - task = dataset.create_data_rows([{ - DataRow.row_data: image_url, - DataRow.external_id: str(i) - } for i in range(10)]) - task.wait_till_done() - - data_rows = list(dataset.data_rows()) - - yield data_rows - for dr in data_rows: - dr.delete() - - -def test_data_row_deletion(dataset, create_datarows_for_data_row_deletion): - create_datarows_for_data_row_deletion - data_rows = list(dataset.data_rows()) - expected = set(map(str, range(10))) - assert {dr.external_id for dr in data_rows} == expected - - for dr in data_rows: - if dr.external_id in "37": - dr.delete() - expected -= set("37") - - data_rows = list(dataset.data_rows()) - assert {dr.external_id for dr in data_rows} == expected - - DataRow.bulk_delete([dr for dr in data_rows if dr.external_id in "2458"]) - expected -= set("2458") - - data_rows = list(dataset.data_rows()) - assert {dr.external_id for dr in data_rows} == expected - - -def test_data_row_iteration(dataset, image_url) -> None: - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url - }, - { - "row_data": image_url - }, - ]) - task.wait_till_done() - assert next(dataset.data_rows()) - - -def test_data_row_attachments(dataset, image_url): - attachments = [("IMAGE", image_url, "attachment image"), - ("RAW_TEXT", "test-text", None), - ("IMAGE_OVERLAY", image_url, "Overlay"), - ("HTML", image_url, None)] - task = dataset.create_data_rows([{ - "row_data": - image_url, - "external_id": - "test-id", - "attachments": [{ - "type": attachment_type, - "value": attachment_value, - "name": attachment_name - }] - } for attachment_type, attachment_value, attachment_name in attachments]) - - task.wait_till_done() - assert task.status == "COMPLETE" - data_rows = list(dataset.data_rows()) - assert len(data_rows) == len(attachments) - for data_row in data_rows: - assert len(list(data_row.attachments())) == 1 - assert data_row.external_id == "test-id" - - with pytest.raises(ValueError) as exc: - task = dataset.create_data_rows([{ - "row_data": image_url, - "external_id": "test-id", - "attachments": [{ - "type": "INVALID", - "value": 
"123" - }] - }]) - - -def test_create_data_rows_sync_attachments(dataset, image_url): - attachments = [("IMAGE", image_url, "image URL"), - ("RAW_TEXT", "test-text", None), - ("IMAGE_OVERLAY", image_url, "Overlay"), - ("HTML", image_url, None)] - attachments_per_data_row = 3 - dataset.create_data_rows_sync([{ - "row_data": - image_url, - "external_id": - "test-id", - "attachments": [{ - "type": attachment_type, - "value": attachment_value, - "name": attachment_name - } for _ in range(attachments_per_data_row)] - } for attachment_type, attachment_value, attachment_name in attachments]) - data_rows = list(dataset.data_rows()) - assert len(data_rows) == len(attachments) - for data_row in data_rows: - assert len(list(data_row.attachments())) == attachments_per_data_row - - -def test_create_data_rows_sync_mixed_upload(dataset, image_url): - n_local = 100 - n_urls = 100 - with NamedTemporaryFile() as fp: - fp.write("Test data".encode()) - fp.flush() - dataset.create_data_rows_sync([{ - DataRow.row_data: image_url - }] * n_urls + [fp.name] * n_local) - assert len(list(dataset.data_rows())) == n_local + n_urls - - -def test_delete_data_row_attachment(data_row, image_url): - attachments = [] - - # Anonymous attachment - to_attach = [("IMAGE", image_url), ("RAW_TEXT", "test-text"), - ("IMAGE_OVERLAY", image_url), ("HTML", image_url)] - for attachment_type, attachment_value in to_attach: - attachments.append( - data_row.create_attachment(attachment_type, attachment_value)) - - # Attachment with a name - to_attach = [("IMAGE", image_url, "Att. Image"), - ("RAW_TEXT", "test-text", "Att. Text"), - ("IMAGE_OVERLAY", image_url, "Image Overlay"), - ("HTML", image_url, "Att. HTML")] - for attachment_type, attachment_value, attachment_name in to_attach: - attachments.append( - data_row.create_attachment(attachment_type, attachment_value, - attachment_name)) - - for attachment in attachments: - attachment.delete() - - assert len(list(data_row.attachments())) == 0 - - -def test_create_data_rows_result(client, dataset, image_url): - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url, - DataRow.external_id: "row1", - }, - { - DataRow.row_data: image_url, - DataRow.external_id: "row1", - }, - ]) - assert task.errors is None - for result in task.result: - client.get_data_row(result['id']) - - -def test_create_data_rows_local_file(dataset, sample_image): - task = dataset.create_data_rows([{ - DataRow.row_data: sample_image, - DataRow.metadata_fields: make_metadata_fields() - }]) - task.wait_till_done() - assert task.status == "COMPLETE" - data_row = list(dataset.data_rows())[0] - assert data_row.external_id == "tests/integration/media/sample_image.jpg" - assert len(data_row.metadata_fields) == 3 - - -def test_data_row_with_global_key(dataset, sample_image): - global_key = str(uuid.uuid4()) - row = dataset.create_data_row({ - DataRow.row_data: sample_image, - DataRow.global_key: global_key - }) - - assert row.global_key == global_key - - -def test_data_row_bulk_creation_with_unique_global_keys(dataset, sample_image): - global_key_1 = str(uuid.uuid4()) - global_key_2 = str(uuid.uuid4()) - global_key_3 = str(uuid.uuid4()) - - task = dataset.create_data_rows([ - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }, - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_2 - }, - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_3 - }, - ]) - - task.wait_till_done() - assert {row.global_key for row in dataset.data_rows() - } == {global_key_1, 
global_key_2, global_key_3} - - -def test_data_row_bulk_creation_with_same_global_keys(dataset, sample_image, - snapshot): - global_key_1 = str(uuid.uuid4()) - task = dataset.create_data_rows([{ - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }, { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }]) - - task.wait_till_done() - - assert task.status == "COMPLETE" - assert type(task.failed_data_rows) is list - assert len(task.failed_data_rows) == 1 - assert type(task.created_data_rows) is list - assert len(task.created_data_rows) == 1 - assert task.failed_data_rows[0][ - 'message'] == f"Duplicate global key: '{global_key_1}'" - assert task.failed_data_rows[0]['failedDataRows'][0][ - 'externalId'] == sample_image - assert task.created_data_rows[0]['externalId'] == sample_image - assert task.created_data_rows[0]['globalKey'] == global_key_1 - - -def test_data_row_delete_and_create_with_same_global_key( - client, dataset, sample_image): - global_key_1 = str(uuid.uuid4()) - data_row_payload = { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - } - - # should successfully insert new datarow - task = dataset.create_data_rows([data_row_payload]) - task.wait_till_done() - - assert task.status == "COMPLETE" - assert task.result[0]['global_key'] == global_key_1 - - new_data_row_id = task.result[0]['id'] - - # same payload should fail due to duplicated global key - task = dataset.create_data_rows([data_row_payload]) - task.wait_till_done() - - assert task.status == "COMPLETE" - assert len(task.failed_data_rows) == 1 - assert task.failed_data_rows[0][ - 'message'] == f"Duplicate global key: '{global_key_1}'" - - # delete datarow - client.get_data_row(new_data_row_id).delete() - - # should successfully insert new datarow now - task = dataset.create_data_rows([data_row_payload]) - task.wait_till_done() - - assert task.status == "COMPLETE" - assert task.result[0]['global_key'] == global_key_1 - - -def test_data_row_bulk_creation_sync_with_unique_global_keys( - dataset, sample_image): - global_key_1 = str(uuid.uuid4()) - global_key_2 = str(uuid.uuid4()) - global_key_3 = str(uuid.uuid4()) - - dataset.create_data_rows_sync([ - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }, - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_2 - }, - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_3 - }, - ]) - - assert {row.global_key for row in dataset.data_rows() - } == {global_key_1, global_key_2, global_key_3} - - -def test_data_row_bulk_creation_sync_with_same_global_keys( - dataset, sample_image): - global_key_1 = str(uuid.uuid4()) - - with pytest.raises(labelbox.exceptions.MalformedQueryException) as exc_info: - dataset.create_data_rows_sync([{ - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }, { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1 - }]) - - assert len(list(dataset.data_rows())) == 1 - assert list(dataset.data_rows())[0].global_key == global_key_1 - assert "Some data rows were not imported. 
Check error output here" in str( - exc_info.value) - - -@pytest.fixture -def converstational_data_rows(dataset, conversational_content): - examples = [ - { - **conversational_content, 'media_type': - MediaType.Conversational.value - }, - conversational_content, - { - "conversationalData": conversational_content['row_data']['messages'] - } # Old way to check for backwards compatibility - ] - task = dataset.create_data_rows(examples) - task.wait_till_done() - assert task.status == "COMPLETE" - - data_rows = list(dataset.data_rows()) - - yield data_rows - for dr in data_rows: - dr.delete() - - -def test_create_conversational_text(converstational_data_rows, - conversational_content): - data_rows = converstational_data_rows - for data_row in data_rows: - assert requests.get( - data_row.row_data).json() == conversational_content['row_data'] - - -def test_invalid_media_type(dataset, conversational_content): - for _, __ in [["Found invalid contents for media type: 'IMAGE'", 'IMAGE'], - [ - "Found invalid media type: 'totallyinvalid'", - 'totallyinvalid' - ]]: - # TODO: What error kind should this be? It looks like for global key we are - # using malformed query. But for invalid contents in FileUploads we use InvalidQueryError - with pytest.raises(labelbox.exceptions.InvalidQueryError): - dataset.create_data_rows_sync([{ - **conversational_content, 'media_type': 'IMAGE' - }]) - - -def test_create_tiled_layer(dataset, tile_content): - examples = [ - { - **tile_content, 'media_type': 'TMS_GEO' - }, - tile_content, - # Old way to check for backwards compatibility - tile_content['row_data'] - ] - dataset.create_data_rows_sync(examples) - data_rows = list(dataset.data_rows()) - assert len(data_rows) == len(examples) - for data_row in data_rows: - assert json.loads(data_row.row_data) == tile_content['row_data'] - - -def test_create_data_row_with_attachments(dataset): - attachment_value = 'attachment value' - dr = dataset.create_data_row(row_data="123", - attachments=[{ - 'type': 'RAW_TEXT', - 'value': attachment_value - }]) - attachments = list(dr.attachments()) - assert len(attachments) == 1 - - -def test_create_data_row_with_media_type(dataset, image_url): - with pytest.raises(labelbox.exceptions.InvalidQueryError) as exc: - dr = dataset.create_data_row( - row_data={'invalid_object': 'invalid_value'}, media_type="IMAGE") - - assert "Media type validation failed, expected: 'image/*', was: application/json" in str( - exc.value) - - dataset.create_data_row(row_data=image_url, media_type="IMAGE") - ----- -tests/integration/test_user_management.py -from labelbox import ProjectRole -import pytest - - -def test_org_invite(client, organization, environ, queries): - role = client.get_roles()['LABELER'] - dummy_email = "none@labelbox.com" - invite_limit = organization.invite_limit() - - if environ.value == "prod": - assert invite_limit.remaining > 0, "No invites available for the account associated with this key." 
-    elif environ.value != "staging":
-        # Cannot run against local
-        return
-
-    invite = organization.invite_user(dummy_email, role)
-
-    if environ.value == "prod":
-        invite_limit_after = organization.invite_limit()
-        # One user added
-        assert invite_limit.remaining - invite_limit_after.remaining == 1
-        # An invite shouldn't affect the user count until after it is accepted
-
-    outstanding_invites = queries.get_invites(client)
-    in_list = False
-
-    # Look for the invite created above among the outstanding invites.
-    for outstanding_invite in outstanding_invites:
-        if outstanding_invite.uid == invite.uid:
-            in_list = True
-            org_role = outstanding_invite.organization_role_name.lower()
-            assert org_role == role.name.lower(
-            ), f"Role should be labeler. Found {org_role}"
-    assert in_list, "Invite not found"
-    queries.cancel_invite(client, invite.uid)
-    assert invite_limit.remaining - organization.invite_limit().remaining == 0
-
-
-def test_project_invite(client, organization, project_pack, queries):
-    project_1, project_2 = project_pack
-    roles = client.get_roles()
-    dummy_email = "none1@labelbox.com"
-    project_role_1 = ProjectRole(project=project_1, role=roles['LABELER'])
-    project_role_2 = ProjectRole(project=project_2, role=roles['REVIEWER'])
-    invite = organization.invite_user(
-        dummy_email,
-        roles['NONE'],
-        project_roles=[project_role_1, project_role_2])
-
-    project_invite = next(queries.get_project_invites(client, project_1.uid))
-
-    assert set([(proj_invite.project.uid, proj_invite.role.uid)
-                for proj_invite in project_invite.project_roles
-               ]) == set([(proj_role.project.uid, proj_role.role.uid)
-                          for proj_role in [project_role_1, project_role_2]])
-
-    project_members = project_1.members()
-
-    project_member = [
-        member for member in project_members
-        if member.user().uid == client.get_user().uid
-    ]
-
-    assert len(project_member) == 1
-    project_member = project_member[0]
-
-    assert project_member.access_from == 'ORGANIZATION'
-    assert project_member.role().name.upper() == roles['ADMIN'].name.upper()
-    queries.cancel_invite(client, invite.uid)
-
-
-@pytest.mark.skip(
-    "Unable to programmatically create user without accepting an email invite. Add back once there is a workaround.")
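-# Skipped end-to-end coverage: the test below walks a provisioned user through
-# project-role upserts, removal from a project, and org-role updates, then
-# removes them from the organization. It needs a user who never accepted an
-# email invite, which cannot currently be provisioned through the SDK.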
-def test_member_management(client, organization, project, project_based_user):
-    roles = client.get_roles()
-    assert not len(list(project_based_user.projects()))
-    for role in [roles['LABELER'], roles['REVIEWER']]:
-        project_based_user.upsert_project_role(project, role=role)
-        members = project.members()
-        is_member = False
-        for member in members:
-            if member.user().uid == project_based_user.uid:
-                is_member = True
-                assert member.role().name.upper() == role.name.upper()
-                break
-        assert is_member
-
-    project_based_user.remove_from_project(project)
-    for member in project.members():
-        assert member.user().uid != project_based_user.uid
-
-    assert project_based_user.org_role().name.upper(
-    ) == roles['NONE'].name.upper()
-    for role in [
-            roles['TEAM_MANAGER'], roles['ADMIN'], roles['LABELER'],
-            roles['REVIEWER']
-    ]:
-        project_based_user.update_org_role(role)
-        assert project_based_user.org_role().name.upper() == role.name.upper()
-
-    organization.remove_user(project_based_user)
-    for user in organization.users():
-        assert project_based_user.uid != user.uid
-
-----
-tests/integration/test_ephemeral.py
-import os
-import pytest
-from support.integration_client import Environ, EphemeralClient, IntegrationClient
-
-
-@pytest.mark.skipif(
-    not os.environ.get('LABELBOX_TEST_ENVIRON') == Environ.EPHEMERAL.value,
-    reason='This test only runs in EPHEMERAL environment')
-def test_org_and_user_setup(client):
-    assert type(client) == EphemeralClient
-    assert client.admin_client
-    assert client.api_key != client.admin_client.api_key
-
-    organization = client.get_organization()
-    assert organization
-    user = client.get_user()
-    assert user
-
-
-@pytest.mark.skipif(
-    os.environ.get('LABELBOX_TEST_ENVIRON') == Environ.EPHEMERAL.value,
-    reason='This test does not run in EPHEMERAL environment')
-def test_integration_client(client):
-    assert type(client) == IntegrationClient
-
-----
-tests/integration/test_data_upload.py
-import pytest
-import requests
-
-
-def test_file_upload(client, rand_gen, dataset):
-    data = rand_gen(str)
-    uri = client.upload_data(data.encode())
-    data_row = dataset.create_data_row(row_data=uri)
-    assert requests.get(data_row.row_data).text == data
-
-----
-tests/integration/test_feature_schema.py
-import pytest
-
-from labelbox import Tool, MediaType
-
-point = Tool(
-    tool=Tool.Type.POINT,
-    name="name",
-    color="#ff0000",
-)
-
-
-def test_deletes_a_feature_schema(client):
-    tool = client.upsert_feature_schema(point.asdict())
-
-    assert client.delete_unused_feature_schema(
-        tool.normalized['featureSchemaId']) is None
-
-
-def test_cant_delete_already_deleted_feature_schema(client):
-    tool = client.upsert_feature_schema(point.asdict())
-    feature_schema_id = tool.normalized['featureSchemaId']
-
-    assert client.delete_unused_feature_schema(feature_schema_id) is None
-
-    with pytest.raises(
-            Exception,
-            match=
-            "Failed to delete the feature schema, message: Feature schema is already deleted"
-    ):
-        client.delete_unused_feature_schema(feature_schema_id)
-
-
-def test_cant_delete_feature_schema_with_ontology(client):
-    tool = client.upsert_feature_schema(point.asdict())
-    feature_schema_id = tool.normalized['featureSchemaId']
-    ontology = client.create_ontology_from_feature_schemas(
-        name='ontology name',
-        feature_schema_ids=[feature_schema_id],
-        media_type=MediaType.Image)
-
-    with pytest.raises(
-            Exception,
-            match=
-            "Failed to delete the feature schema, message: Feature schema cannot be deleted because it is used in ontologies"
-    ):
-        
client.delete_unused_feature_schema(feature_schema_id) - - client.delete_unused_ontology(ontology.uid) - client.delete_unused_feature_schema(feature_schema_id) - - -def test_throws_an_error_if_feature_schema_to_delete_doesnt_exist(client): - with pytest.raises( - Exception, - match= - "Failed to delete the feature schema, message: Cannot find root schema node with feature schema id doesntexist" - ): - client.delete_unused_feature_schema("doesntexist") - - -def test_updates_a_feature_schema_title(client): - tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] - new_title = "new title" - updated_feature_schema = client.update_feature_schema_title( - feature_schema_id, new_title) - - assert updated_feature_schema.normalized['name'] == new_title - - client.delete_unused_feature_schema(feature_schema_id) - - -def test_throws_an_error_when_updating_a_feature_schema_with_empty_title( - client): - tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] - - with pytest.raises(Exception): - client.update_feature_schema_title(feature_schema_id, "") - - client.delete_unused_feature_schema(feature_schema_id) - - -def test_throws_an_error_when_updating_not_existing_feature_schema(client): - with pytest.raises(Exception): - client.update_feature_schema_title("doesntexist", "new title") - - -def test_creates_a_new_feature_schema(client): - created_feature_schema = client.upsert_feature_schema(point.asdict()) - - assert created_feature_schema.uid is not None - - client.delete_unused_feature_schema( - created_feature_schema.normalized['featureSchemaId']) - - -def test_updates_a_feature_schema(client): - tool = Tool( - tool=Tool.Type.POINT, - name="name", - color="#ff0000", - ) - created_feature_schema = client.upsert_feature_schema(tool.asdict()) - tool_to_update = Tool( - tool=Tool.Type.POINT, - name="new name", - color="#ff0000", - feature_schema_id=created_feature_schema.normalized['featureSchemaId'], - ) - updated_feature_schema = client.upsert_feature_schema( - tool_to_update.asdict()) - - assert updated_feature_schema.normalized['name'] == "new name" - - -def test_does_not_include_used_feature_schema(client): - tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] - ontology = client.create_ontology_from_feature_schemas( - name='ontology name', - feature_schema_ids=[feature_schema_id], - media_type=MediaType.Image) - unused_feature_schemas = client.get_unused_feature_schemas() - - assert feature_schema_id not in unused_feature_schemas - - client.delete_unused_ontology(ontology.uid) - client.delete_unused_feature_schema(feature_schema_id) - ----- -tests/integration/test_benchmark.py -def test_benchmark(configured_project_with_label): - project, _, data_row, label = configured_project_with_label - assert set(project.benchmarks()) == set() - assert label.is_benchmark_reference == False - - benchmark = label.create_benchmark() - assert set(project.benchmarks()) == {benchmark} - assert benchmark.reference_label() == label - # Refresh label data to check it's benchmark reference - label = list(data_row.labels())[0] - assert label.is_benchmark_reference == True - - benchmark.delete() - assert set(project.benchmarks()) == set() - # Refresh label data to check it's benchmark reference - label = list(data_row.labels())[0] - assert label.is_benchmark_reference == False - ----- -tests/integration/test_toggle_mal.py -def 
test_enable_model_assisted_labeling(project):
-    response = project.enable_model_assisted_labeling()
-    assert response == True
-
-    response = project.enable_model_assisted_labeling(True)
-    assert response == True
-
-    response = project.enable_model_assisted_labeling(False)
-    assert response == False
-
-----
-tests/integration/test_dataset.py
-import pytest
-import requests
-from labelbox import Dataset
-from labelbox.exceptions import ResourceNotFoundError, MalformedQueryException, InvalidQueryError
-from labelbox.schema.dataset import MAX_DATAROW_PER_API_OPERATION
-
-
-def test_dataset(client, rand_gen):
-
-    # confirm dataset can be created
-    name = rand_gen(str)
-    dataset = client.create_dataset(name=name)
-    assert dataset.name == name
-    assert dataset.created_by() == client.get_user()
-    assert dataset.organization() == client.get_organization()
-
-    retrieved_dataset = client.get_dataset(dataset.uid)
-    assert retrieved_dataset.name == dataset.name
-    assert retrieved_dataset.uid == dataset.uid
-    assert retrieved_dataset.created_by() == dataset.created_by()
-    assert retrieved_dataset.organization() == dataset.organization()
-
-    dataset = client.get_dataset(dataset.uid)
-    assert dataset.name == name
-
-    new_name = rand_gen(str)
-    dataset.update(name=new_name)
-    # Test local object updates.
-    assert dataset.name == new_name
-
-    # Test remote updates.
-    dataset = client.get_dataset(dataset.uid)
-    assert dataset.name == new_name
-
-    # Test description
-    description = rand_gen(str)
-    assert dataset.description == ""
-    dataset.update(description=description)
-    assert dataset.description == description
-
-    dataset.delete()
-
-    with pytest.raises(ResourceNotFoundError):
-        dataset = client.get_dataset(dataset.uid)
-
-
-@pytest.fixture
-def dataset_for_filtering(client, rand_gen):
-    name_1 = rand_gen(str)
-    name_2 = rand_gen(str)
-    d1 = client.create_dataset(name=name_1)
-    d2 = client.create_dataset(name=name_2)
-
-    yield name_1, d1, name_2, d2
-
-
-def test_dataset_filtering(client, dataset_for_filtering):
-    name_1, d1, name_2, d2 = dataset_for_filtering
-
-    assert list(client.get_datasets(where=Dataset.name == name_1)) == [d1]
-    assert list(client.get_datasets(where=Dataset.name == name_2)) == [d2]
-
-
-def test_get_data_row_for_external_id(dataset, rand_gen, image_url):
-    external_id = rand_gen(str)
-
-    with pytest.raises(ResourceNotFoundError):
-        data_row = dataset.data_row_for_external_id(external_id)
-
-    data_row = dataset.create_data_row(row_data=image_url,
-                                       external_id=external_id)
-
-    found = dataset.data_row_for_external_id(external_id)
-    assert found.uid == data_row.uid
-    assert found.external_id == external_id
-
-    dataset.create_data_row(row_data=image_url, external_id=external_id)
-    assert len(dataset.data_rows_for_external_id(external_id)) == 2
-
-    task = dataset.create_data_rows(
-        [dict(row_data=image_url, external_id=external_id)])
-    task.wait_till_done()
-    assert len(dataset.data_rows_for_external_id(external_id)) == 3
-
-
-def test_upload_video_file(dataset, sample_video: str) -> None:
-    """
-    Tests that an mp4 video can be uploaded and preserve content length
-    and content type.
- - """ - dataset.create_data_row(row_data=sample_video) - task = dataset.create_data_rows([sample_video, sample_video]) - task.wait_till_done() - - with open(sample_video, 'rb') as video_f: - content_length = len(video_f.read()) - - for data_row in dataset.data_rows(): - url = data_row.row_data - response = requests.head(url, allow_redirects=True) - assert int(response.headers['Content-Length']) == content_length - assert response.headers['Content-Type'] == 'video/mp4' - - -def test_create_pdf(dataset): - dataset.create_data_row( - row_data={ - "pdfUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", - "textLayerUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json" - }) - dataset.create_data_row(row_data={ - "pdfUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", - "textLayerUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json" - }, - media_type="PDF") - - with pytest.raises(InvalidQueryError): - # Wrong media type - dataset.create_data_row(row_data={ - "pdfUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-1.pdf", - "textLayerUrl": - "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/sample-document-custom-text-layer.json" - }, - media_type="TEXT") - - -def test_bulk_conversation(dataset, sample_bulk_conversation: list) -> None: - """ - Tests that bulk conversations can be uploaded. - - """ - task = dataset.create_data_rows(sample_bulk_conversation) - task.wait_till_done() - - assert len(list(dataset.data_rows())) == len(sample_bulk_conversation) - - -def test_create_descriptor_file(dataset): - import unittest.mock as mock - with mock.patch.object(dataset.client, - 'upload_data', - wraps=dataset.client.upload_data) as upload_data_spy: - dataset._create_descriptor_file(items=[{'row_data': 'some text...'}]) - upload_data_spy.assert_called() - call_args, call_kwargs = upload_data_spy.call_args_list[0][ - 0], upload_data_spy.call_args_list[0][1] - assert call_args == ('[{"row_data": "some text..."}]',) - assert call_kwargs == { - 'content_type': 'application/json', - 'filename': 'json_import.json' - } - - -def test_max_dataset_datarow_upload(dataset, image_url, rand_gen): - external_id = str(rand_gen) - items = [dict(row_data=image_url, external_id=external_id) - ] * (MAX_DATAROW_PER_API_OPERATION + 1) - - with pytest.raises(MalformedQueryException): - dataset.create_data_rows(items) - ----- -tests/integration/test_legacy_project.py -import pytest - -from labelbox.schema.queue_mode import QueueMode - - -def test_project_dataset(client, rand_gen): - with pytest.raises( - ValueError, - match= - "Dataset queue mode is deprecated. Please prefer Batch queue mode." - ): - client.create_project( - name=rand_gen(str), - queue_mode=QueueMode.Dataset, - ) - - -def test_project_auto_audit_parameters(client, rand_gen): - with pytest.raises( - ValueError, - match= - "quality_mode must be set instead of auto_audit_percentage or auto_audit_number_of_labels." - ): - client.create_project(name=rand_gen(str), auto_audit_percentage=0.5) - - with pytest.raises( - ValueError, - match= - "quality_mode must be set instead of auto_audit_percentage or auto_audit_number_of_labels." 
- ): - client.create_project(name=rand_gen(str), auto_audit_number_of_labels=2) - - -def test_project_name_parameter(client, rand_gen): - with pytest.raises(ValueError, - match="project name must be a valid string."): - client.create_project() - - with pytest.raises(ValueError, - match="project name must be a valid string."): - client.create_project(name=" ") - ----- -tests/integration/test_data_row_metadata.py -from datetime import datetime - -import pytest -import uuid - -from labelbox import Dataset -from labelbox.exceptions import MalformedQueryException -from labelbox.schema.identifiables import GlobalKeys, UniqueIds -from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadata, DataRowMetadataKind, DataRowMetadataOntology, _parse_metadata_schema - -INVALID_SCHEMA_ID = "1" * 25 -FAKE_SCHEMA_ID = "0" * 25 -FAKE_DATAROW_ID = "D" * 25 -SPLIT_SCHEMA_ID = "cko8sbczn0002h2dkdaxb5kal" -TRAIN_SPLIT_ID = "cko8sbscr0003h2dk04w86hof" -TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt" -TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh" -CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb" -CUSTOM_TEXT_SCHEMA_NAME = 'custom_text' - -FAKE_NUMBER_FIELD = { - "id": FAKE_SCHEMA_ID, - "name": "number", - "kind": 'CustomMetadataNumber', - "reserved": False -} - - -@pytest.fixture -def mdo(client): - mdo = client.get_data_row_metadata_ontology() - try: - mdo.create_schema(CUSTOM_TEXT_SCHEMA_NAME, DataRowMetadataKind.string) - except MalformedQueryException: - # Do nothing if already exists - pass - mdo._raw_ontology = mdo._get_ontology() - mdo._raw_ontology.append(FAKE_NUMBER_FIELD) - mdo._build_ontology() - yield mdo - - -@pytest.fixture -def big_dataset(dataset: Dataset, image_url): - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image" - }, - ] * 5) - task.wait_till_done() - - yield dataset - - -def make_metadata(dr_id: str = None, gk: str = None) -> DataRowMetadata: - msg = "A message" - time = datetime.utcnow() - - metadata = DataRowMetadata( - global_key=gk, - data_row_id=dr_id, - fields=[ - DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID, - value=TEST_SPLIT_ID), - DataRowMetadataField(schema_id=CAPTURE_DT_SCHEMA_ID, value=time), - DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value=msg), - ]) - return metadata - - -def make_named_metadata(dr_id) -> DataRowMetadata: - msg = "A message" - time = datetime.utcnow() - - metadata = DataRowMetadata(data_row_id=dr_id, - fields=[ - DataRowMetadataField(name='split', - value=TEST_SPLIT_ID), - DataRowMetadataField(name='captureDateTime', - value=time), - DataRowMetadataField( - name=CUSTOM_TEXT_SCHEMA_NAME, value=msg), - ]) - return metadata - - -def test_export_empty_metadata(client, configured_project_with_label, - wait_for_data_row_processing): - project, _, data_row, _ = configured_project_with_label - data_row = wait_for_data_row_processing(client, data_row) - labels = project.label_generator() - label = next(labels) - assert label.data.metadata == [] - - -def test_bulk_export_datarow_metadata(data_row, mdo: DataRowMetadataOntology): - metadata = make_metadata(data_row.uid) - mdo.bulk_upsert([metadata]) - exported = mdo.bulk_export([data_row.uid]) - assert exported[0].global_key == data_row.global_key - assert exported[0].data_row_id == data_row.uid - assert len([field for field in exported[0].fields]) == 3 - - exported = mdo.bulk_export(UniqueIds([data_row.uid])) - assert exported[0].global_key == data_row.global_key - assert exported[0].data_row_id == data_row.uid - assert len([field for field in 
exported[0].fields]) == 3 - - exported = mdo.bulk_export(GlobalKeys([data_row.global_key])) - assert exported[0].global_key == data_row.global_key - assert exported[0].data_row_id == data_row.uid - assert len([field for field in exported[0].fields]) == 3 - - -def test_get_datarow_metadata_ontology(mdo): - assert len(mdo.fields) - assert len(mdo.reserved_fields) - # two are created by mdo fixture but there may be more - assert len(mdo.custom_fields) >= 2 - - split = mdo.reserved_by_name["split"]["train"] - - assert DataRowMetadata( - data_row_id=FAKE_DATAROW_ID, - fields=[ - DataRowMetadataField( - schema_id=mdo.reserved_by_name["captureDateTime"].uid, - value=datetime.utcnow(), - ), - DataRowMetadataField(schema_id=split.parent, value=split.uid), - DataRowMetadataField(schema_id=mdo.reserved_by_name["tag"].uid, - value="hello-world"), - ]) - - -def test_bulk_upsert_datarow_metadata(data_row, mdo: DataRowMetadataOntology): - metadata = make_metadata(data_row.uid) - mdo.bulk_upsert([metadata]) - exported = mdo.bulk_export([data_row.uid]) - assert len(exported) - assert len([field for field in exported[0].fields]) == 3 - - -def test_bulk_upsert_datarow_metadata_by_globalkey( - data_rows, mdo: DataRowMetadataOntology): - global_keys = [data_row.global_key for data_row in data_rows] - metadata = [make_metadata(gk=global_key) for global_key in global_keys] - errors = mdo.bulk_upsert(metadata) - assert len(errors) == 0 - - -@pytest.mark.slow -def test_large_bulk_upsert_datarow_metadata(big_dataset, mdo): - metadata = [] - data_row_ids = [dr.uid for dr in big_dataset.data_rows()] - for data_row_id in data_row_ids: - metadata.append(make_metadata(data_row_id)) - errors = mdo.bulk_upsert(metadata) - assert len(errors) == 0 - - metadata_lookup = { - metadata.data_row_id: metadata - for metadata in mdo.bulk_export(data_row_ids) - } - for data_row_id in data_row_ids: - assert len([f for f in metadata_lookup.get(data_row_id).fields - ]), metadata_lookup.get(data_row_id).fields - - -def test_upsert_datarow_metadata_by_name(data_row, mdo): - metadata = [make_named_metadata(data_row.uid)] - errors = mdo.bulk_upsert(metadata) - assert len(errors) == 0 - - metadata_lookup = { - metadata.data_row_id: metadata - for metadata in mdo.bulk_export([data_row.uid]) - } - assert len([f for f in metadata_lookup.get(data_row.uid).fields - ]), metadata_lookup.get(data_row.uid).fields - - -def test_upsert_datarow_metadata_option_by_name(data_row, mdo): - metadata = DataRowMetadata(data_row_id=data_row.uid, - fields=[ - DataRowMetadataField(name='split', - value='test'), - ]) - errors = mdo.bulk_upsert([metadata]) - assert len(errors) == 0 - - datarows = mdo.bulk_export([data_row.uid]) - assert len(datarows[0].fields) == 1 - metadata = datarows[0].fields[0] - assert metadata.schema_id == SPLIT_SCHEMA_ID - assert metadata.name == 'test' - assert metadata.value == TEST_SPLIT_ID - - -def test_upsert_datarow_metadata_option_by_incorrect_name(data_row, mdo): - metadata = DataRowMetadata(data_row_id=data_row.uid, - fields=[ - DataRowMetadataField(name='split', - value='test1'), - ]) - with pytest.raises(KeyError): - mdo.bulk_upsert([metadata]) - - -def test_raise_enum_upsert_schema_error(data_row, mdo): - """Setting an option id as the schema id will raise a Value Error""" - - metadata = DataRowMetadata(data_row_id=data_row.uid, - fields=[ - DataRowMetadataField(schema_id=TEST_SPLIT_ID, - value=SPLIT_SCHEMA_ID), - ]) - with pytest.raises(ValueError): - mdo.bulk_upsert([metadata]) - - -def 
test_upsert_non_existent_schema_id(data_row, mdo): - """Raise error on non-existent schema id""" - metadata = DataRowMetadata(data_row_id=data_row.uid, - fields=[ - DataRowMetadataField( - schema_id=INVALID_SCHEMA_ID, - value="message"), - ]) - with pytest.raises(ValueError): - mdo.bulk_upsert([metadata]) - - -def test_parse_raw_metadata(mdo): - example = { - 'dataRowId': - 'ckr6kkfx801ui0yrtg9fje8xh', - 'globalKey': - 'global-key-1', - 'fields': [ - { - 'schemaId': 'cko8s9r5v0001h2dk9elqdidh', - 'value': 'my-new-message' - }, - { - 'schemaId': 'cko8sbczn0002h2dkdaxb5kal', - 'value': {} - }, - { - 'schemaId': 'cko8sbscr0003h2dk04w86hof', - 'value': {} - }, - { - 'schemaId': 'cko8sdzv70006h2dk8jg64zvb', - 'value': '2021-07-20T21:41:14.606710Z' - }, - { - 'schemaId': FAKE_SCHEMA_ID, - 'value': 0.5 - }, - ] - } - - parsed = mdo.parse_metadata([example]) - assert len(parsed) == 1 - for row in parsed: - assert row.data_row_id == example["dataRowId"] - assert row.global_key == example["globalKey"] - assert len(row.fields) == 4 - - for row in parsed: - for field in row.fields: - assert mdo._parse_upsert(field) - - -def test_parse_raw_metadata_fields(mdo): - example = [ - { - 'schemaId': 'cko8s9r5v0001h2dk9elqdidh', - 'value': 'my-new-message' - }, - { - 'schemaId': 'cko8sbczn0002h2dkdaxb5kal', - 'value': {} - }, - { - 'schemaId': 'cko8sbscr0003h2dk04w86hof', - 'value': {} - }, - { - 'schemaId': 'cko8sdzv70006h2dk8jg64zvb', - 'value': '2021-07-20T21:41:14.606710Z' - }, - { - 'schemaId': FAKE_SCHEMA_ID, - 'value': 0.5 - }, - ] - - parsed = mdo.parse_metadata_fields(example) - assert len(parsed) == 4 - - for field in parsed: - assert mdo._parse_upsert(field) - - -def test_parse_metadata_schema(): - unparsed = { - 'id': - 'cl467a4ec0046076g7s9yheoa', - 'name': - 'enum metadata', - 'kind': - 'CustomMetadataEnum', - 'options': [{ - 'id': 'cl467a4ec0047076ggjneeruy', - 'name': 'option1', - 'kind': 'CustomMetadataEnumOption' - }, { - 'id': 'cl4qa31u0009e078p5m280jer', - 'name': 'option2', - 'kind': 'CustomMetadataEnumOption' - }] - } - parsed = _parse_metadata_schema(unparsed) - assert parsed.uid == 'cl467a4ec0046076g7s9yheoa' - assert parsed.name == 'enum metadata' - assert parsed.kind == DataRowMetadataKind.enum - assert len(parsed.options) == 2 - assert parsed.options[0].uid == 'cl467a4ec0047076ggjneeruy' - assert parsed.options[0].kind == DataRowMetadataKind.option - - -def test_create_schema(mdo): - metadata_name = str(uuid.uuid4()) - created_schema = mdo.create_schema(metadata_name, DataRowMetadataKind.enum, - ["option 1", "option 2"]) - assert created_schema.name == metadata_name - assert created_schema.kind == DataRowMetadataKind.enum - assert len(created_schema.options) == 2 - assert created_schema.options[0].name == "option 1" - mdo.delete_schema(metadata_name) - - -def test_update_schema(mdo): - metadata_name = str(uuid.uuid4()) - created_schema = mdo.create_schema(metadata_name, DataRowMetadataKind.enum, - ["option 1", "option 2"]) - updated_schema = mdo.update_schema(metadata_name, - f"{metadata_name}_updated") - assert updated_schema.name == f"{metadata_name}_updated" - assert updated_schema.uid == created_schema.uid - assert updated_schema.kind == DataRowMetadataKind.enum - mdo.delete_schema(f"{metadata_name}_updated") - - -def test_update_enum_options(mdo): - metadata_name = str(uuid.uuid4()) - created_schema = mdo.create_schema(metadata_name, DataRowMetadataKind.enum, - ["option 1", "option 2"]) - updated_schema = mdo.update_enum_option(metadata_name, "option 1", - "option 3") - assert 
updated_schema.name == metadata_name
-    assert updated_schema.uid == created_schema.uid
-    assert updated_schema.kind == DataRowMetadataKind.enum
-    assert updated_schema.options[0].uid == created_schema.options[0].uid
-    assert updated_schema.options[0].name == "option 3"
-    mdo.delete_schema(metadata_name)
-
-
-def test_delete_schema(mdo):
-    metadata_name = str(uuid.uuid4())
-    created_schema = mdo.create_schema(metadata_name,
-                                       DataRowMetadataKind.string)
-    status = mdo.delete_schema(created_schema.name)
-    mdo.refresh_ontology()
-    assert status
-    assert metadata_name not in mdo.custom_by_name
-
-
-@pytest.mark.parametrize('datetime_str',
-                         ['2011-11-04T00:05:23Z', '2011-11-04T00:05:23+00:00'])
-def test_upsert_datarow_date_metadata(data_row, mdo, datetime_str):
-    metadata = [
-        DataRowMetadata(data_row_id=data_row.uid,
-                        fields=[
-                            DataRowMetadataField(name='captureDateTime',
-                                                 value=datetime_str),
-                        ])
-    ]
-    errors = mdo.bulk_upsert(metadata)
-    assert len(errors) == 0
-
-    metadata = mdo.bulk_export([data_row.uid])
-    assert f"{metadata[0].fields[0].value}" == "2011-11-04 00:05:23+00:00"
-
-
-@pytest.mark.parametrize('datetime_str',
-                         ['2011-11-04T00:05:23Z', '2011-11-04T00:05:23+00:00'])
-def test_create_data_row_with_metadata(dataset, image_url, datetime_str):
-    client = dataset.client
-    assert len(list(dataset.data_rows())) == 0
-
-    metadata_fields = [
-        DataRowMetadataField(name='captureDateTime', value=datetime_str)
-    ]
-
-    data_row = dataset.create_data_row(row_data=image_url,
-                                       metadata_fields=metadata_fields)
-
-    retrieved_data_row = client.get_data_row(data_row.uid)
-    assert f"{retrieved_data_row.metadata[0].value}" == "2011-11-04 00:05:23+00:00"
-
-----
-tests/integration/test_foundry.py
-import labelbox as lb
-import pytest
-from labelbox.schema.foundry.app import App
-
-from labelbox.schema.foundry.foundry_client import FoundryClient
-
-# YOLO object detection model id
-TEST_MODEL_ID = "e8b352ce-8f3a-4cd6-93a5-8af904307346"
-
-
-@pytest.fixture()
-def random_str(rand_gen):
-    return rand_gen(str)
-
-
-@pytest.fixture(scope="module")
-def foundry_client(client):
-    return FoundryClient(client)
-
-
-@pytest.fixture()
-def text_data_row(dataset, random_str):
-    global_key = f"https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-1.txt-{random_str}"
-    task = dataset.create_data_rows([{
-        "row_data":
-            "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-1.txt",
-        "media_type":
-            "TEXT",
-        "global_key":
-            global_key
-    }])
-    task.wait_till_done()
-    dr = dataset.data_rows().get_one()
-    yield dr
-    dr.delete()
-
-
-@pytest.fixture()
-def ontology(client, random_str):
-    object_features = [
-        lb.Tool(tool=lb.Tool.Type.BBOX,
-                name="text",
-                color="#ff0000",
-                classifications=[
-                    lb.Classification(class_type=lb.Classification.Type.TEXT,
-                                      name="value")
-                ])
-    ]
-
-    ontology_builder = lb.OntologyBuilder(tools=object_features,)
-
-    ontology = client.create_ontology(
-        f"Test ontology for tesseract model {random_str}",
-        ontology_builder.asdict(),
-        media_type=lb.MediaType.Image)
-    return ontology
-
-
-@pytest.fixture()
-def unsaved_app(random_str, ontology):
-    return App(model_id=TEST_MODEL_ID,
-               name=f"Test App {random_str}",
-               description="Test App Description",
-               inference_params={"confidence": 0.2},
-               class_to_schema_id={},
-               ontology_id=ontology.uid)
-
-
-@pytest.fixture()
-def app(foundry_client, unsaved_app):
-    app = foundry_client._create_app(unsaved_app)
-    yield app
-    foundry_client._delete_app(app.id)
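-
-
-# A minimal sketch of the lifecycle these fixtures set up, using only the
-# private FoundryClient helpers exercised in this file:
-#     app = foundry_client._create_app(unsaved_app)   # persist the App spec
-#     foundry_client._get_app(app.id)                 # fetch it back by id
-#     foundry_client._delete_app(app.id)              # teardown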
-def test_create_app(foundry_client, unsaved_app):
-    app = foundry_client._create_app(unsaved_app)
-    retrieved_dict = app.dict(exclude={'id', 'created_by'})
-    expected_dict = unsaved_app.dict(exclude={'id', 'created_by'})
-    assert retrieved_dict == expected_dict
-
-
-def test_get_app(foundry_client, app):
-    retrieved_app = foundry_client._get_app(app.id)
-    retrieved_dict = retrieved_app.dict(exclude={'created_by'})
-    expected_dict = app.dict(exclude={'created_by'})
-    assert retrieved_dict == expected_dict
-
-
-def test_get_app_with_invalid_id(foundry_client):
-    with pytest.raises(lb.exceptions.ResourceNotFoundError):
-        foundry_client._get_app("invalid-id")
-
-
-def test_run_foundry_app_with_data_row_id(foundry_client, data_row, app,
-                                          random_str):
-    data_rows = lb.DataRowIds([data_row.uid])
-    task = foundry_client.run_app(
-        model_run_name=f"test-app-with-datarow-id-{random_str}",
-        data_rows=data_rows,
-        app_id=app.id)
-    task.wait_till_done()
-    assert task.status == 'COMPLETE'
-
-
-def test_run_foundry_app_with_global_key(foundry_client, data_row, app,
-                                         random_str):
-    data_rows = lb.GlobalKeys([data_row.global_key])
-    task = foundry_client.run_app(
-        model_run_name=f"test-app-with-global-key-{random_str}",
-        data_rows=data_rows,
-        app_id=app.id)
-    task.wait_till_done()
-    assert task.status == 'COMPLETE'
-
-
-def test_run_foundry_app_returns_model_run_id(foundry_client, data_row, app,
-                                              random_str):
-    data_rows = lb.GlobalKeys([data_row.global_key])
-    task = foundry_client.run_app(
-        model_run_name=f"test-app-with-global-key-{random_str}",
-        data_rows=data_rows,
-        app_id=app.id)
-    model_run_id = task.metadata['modelRunId']
-    model_run = foundry_client.client.get_model_run(model_run_id)
-    assert model_run.uid == model_run_id
-
-
-def test_run_foundry_with_invalid_data_row_id(foundry_client, app, random_str):
-    invalid_datarow_id = 'invalid-datarow-id'
-    data_rows = lb.DataRowIds([invalid_datarow_id])
-    with pytest.raises(lb.exceptions.LabelboxError) as exception:
-        foundry_client.run_app(
-            model_run_name=f"test-app-with-invalid-datarow-id-{random_str}",
-            data_rows=data_rows,
-            app_id=app.id)
-    assert invalid_datarow_id in str(exception.value)
-
-
-def test_run_foundry_with_invalid_global_key(foundry_client, app, random_str):
-    invalid_global_key = 'invalid-global-key'
-    data_rows = lb.GlobalKeys([invalid_global_key])
-    with pytest.raises(lb.exceptions.LabelboxError) as exception:
-        foundry_client.run_app(
-            model_run_name=f"test-app-with-invalid-global-key-{random_str}",
-            data_rows=data_rows,
-            app_id=app.id)
-    assert invalid_global_key in str(exception.value)
-
-----
-tests/integration/test_global_keys.py
-import uuid
-import pytest
-
-
-def test_assign_global_keys_to_data_rows(client, dataset, image_url):
-    """Test that the assign_global_keys_to_data_rows method can be called
-    with a valid list of AssignGlobalKeyToDataRowInput objects.
- """ - - dr_1 = dataset.create_data_row(row_data=image_url, external_id="hello") - dr_2 = dataset.create_data_row(row_data=image_url, external_id="world") - row_ids = set([dr_1.uid, dr_2.uid]) - - gk_1 = str(uuid.uuid4()) - gk_2 = str(uuid.uuid4()) - - assignment_inputs = [{ - "data_row_id": dr_1.uid, - "global_key": gk_1 - }, { - "data_row_id": dr_2.uid, - "global_key": gk_2 - }] - res = client.assign_global_keys_to_data_rows(assignment_inputs) - assert res['status'] == "SUCCESS" - assert res['errors'] == [] - - assert len(res['results']) == 2 - for r in res['results']: - del r['sanitized'] - assert res['results'] == assignment_inputs - - -def test_assign_global_keys_to_data_rows_validation_error(client): - assignment_inputs = [{ - "data_row_id": "test uid", - "wrong_key": "gk 1" - }, { - "data_row_id": "test uid 2", - "global_key": "gk 2" - }, { - "wrong_key": "test uid 3", - "global_key": "gk 3" - }, { - "data_row_id": "test uid 4" - }, { - "global_key": "gk 5" - }, {}] - with pytest.raises(ValueError) as excinfo: - client.assign_global_keys_to_data_rows(assignment_inputs) - e = """[{'data_row_id': 'test uid', 'wrong_key': 'gk 1'}, {'wrong_key': 'test uid 3', 'global_key': 'gk 3'}, {'data_row_id': 'test uid 4'}, {'global_key': 'gk 5'}, {}]""" - assert e in str(excinfo.value) - - -def test_assign_same_global_keys_to_data_rows(client, dataset, image_url): - dr_1 = dataset.create_data_row(row_data=image_url, external_id="hello") - dr_2 = dataset.create_data_row(row_data=image_url, external_id="world") - - gk_1 = str(uuid.uuid4()) - - assignment_inputs = [{ - "data_row_id": dr_1.uid, - "global_key": gk_1 - }, { - "data_row_id": dr_2.uid, - "global_key": gk_1 - }] - res = client.assign_global_keys_to_data_rows(assignment_inputs) - - assert res['status'] == "PARTIAL SUCCESS" - assert len(res['results']) == 1 - assert res['results'][0]['data_row_id'] == dr_1.uid - assert res['results'][0]['global_key'] == gk_1 - - assert len(res['errors']) == 1 - assert res['errors'][0]['data_row_id'] == dr_2.uid - assert res['errors'][0]['global_key'] == gk_1 - assert res['errors'][0][ - 'error'] == "Invalid assignment. Either DataRow does not exist, or globalKey is invalid" - - -def test_long_global_key_validation(client, dataset, image_url): - long_global_key = 'x' * 201 - dr_1 = dataset.create_data_row(row_data=image_url) - dr_2 = dataset.create_data_row(row_data=image_url) - - gk_1 = str(uuid.uuid4()) - gk_2 = long_global_key - - assignment_inputs = [{ - "data_row_id": dr_1.uid, - "global_key": gk_1 - }, { - "data_row_id": dr_2.uid, - "global_key": gk_2 - }] - res = client.assign_global_keys_to_data_rows(assignment_inputs) - - assert len(res['results']) == 1 - assert len(res['errors']) == 1 - assert res['status'] == 'PARTIAL SUCCESS' - assert res['results'][0]['data_row_id'] == dr_1.uid - assert res['results'][0]['global_key'] == gk_1 - assert res['errors'][0]['data_row_id'] == dr_2.uid - assert res['errors'][0]['global_key'] == gk_2 - assert res['errors'][0][ - 'error'] == 'Invalid assignment. 
Either DataRow does not exist, or globalKey is invalid' - - -def test_global_key_with_whitespaces_validation(client, dataset, image_url): - dr_1 = dataset.create_data_row(row_data=image_url) - dr_2 = dataset.create_data_row(row_data=image_url) - dr_3 = dataset.create_data_row(row_data=image_url) - - gk_1 = ' global key' - gk_2 = 'global key' - gk_3 = 'global key ' - - assignment_inputs = [{ - "data_row_id": dr_1.uid, - "global_key": gk_1 - }, { - "data_row_id": dr_2.uid, - "global_key": gk_2 - }, { - "data_row_id": dr_3.uid, - "global_key": gk_3 - }] - res = client.assign_global_keys_to_data_rows(assignment_inputs) - - assert len(res['results']) == 0 - assert len(res['errors']) == 3 - assert res['status'] == 'FAILURE' - assign_errors_ids = set([e['data_row_id'] for e in res['errors']]) - assign_errors_gks = set([e['global_key'] for e in res['errors']]) - assign_errors_msgs = set([e['error'] for e in res['errors']]) - assert assign_errors_ids == set([dr_1.uid, dr_2.uid, dr_3.uid]) - assert assign_errors_gks == set([gk_1, gk_2, gk_3]) - assert assign_errors_msgs == set([ - 'Invalid assignment. Either DataRow does not exist, or globalKey is invalid', - 'Invalid assignment. Either DataRow does not exist, or globalKey is invalid', - 'Invalid assignment. Either DataRow does not exist, or globalKey is invalid' - ]) - - -def test_get_data_row_ids_for_global_keys(client, dataset, image_url): - gk_1 = str(uuid.uuid4()) - gk_2 = str(uuid.uuid4()) - - dr_1 = dataset.create_data_row(row_data=image_url, - external_id="hello", - global_key=gk_1) - dr_2 = dataset.create_data_row(row_data=image_url, - external_id="world", - global_key=gk_2) - - res = client.get_data_row_ids_for_global_keys([gk_1]) - assert res['status'] == "SUCCESS" - assert res['errors'] == [] - assert res['results'] == [dr_1.uid] - - res = client.get_data_row_ids_for_global_keys([gk_2]) - assert res['status'] == "SUCCESS" - assert res['errors'] == [] - assert res['results'] == [dr_2.uid] - - res = client.get_data_row_ids_for_global_keys([gk_1, gk_2]) - assert res['status'] == "SUCCESS" - assert res['errors'] == [] - assert res['results'] == [dr_1.uid, dr_2.uid] - - -def test_get_data_row_ids_for_invalid_global_keys(client, dataset, image_url): - gk_1 = str(uuid.uuid4()) - gk_2 = str(uuid.uuid4()) - - dr_1 = dataset.create_data_row(row_data=image_url, external_id="hello") - dr_2 = dataset.create_data_row(row_data=image_url, - external_id="world", - global_key=gk_2) - - res = client.get_data_row_ids_for_global_keys([gk_1]) - assert res['status'] == "FAILURE" - assert len(res['errors']) == 1 - assert res['errors'][0]['error'] == "Data Row not found" - assert res['errors'][0]['global_key'] == gk_1 - - res = client.get_data_row_ids_for_global_keys([gk_1, gk_2]) - assert res['status'] == "PARTIAL SUCCESS" - - assert len(res['errors']) == 1 - assert len(res['results']) == 2 - - assert res['errors'][0]['error'] == "Data Row not found" - assert res['errors'][0]['global_key'] == gk_1 - - assert res['results'][0] == '' - assert res['results'][1] == dr_2.uid - ----- -tests/integration/test_project.py -import time -import os -import uuid - -import pytest -import requests - -from labelbox import Project, LabelingFrontend, Dataset -from labelbox.exceptions import InvalidQueryError -from labelbox.schema.media_type import MediaType -from labelbox.schema.quality_mode import QualityMode -from labelbox.schema.queue_mode import QueueMode - - -def test_project(client, rand_gen): - data = { - "name": rand_gen(str), - "description": rand_gen(str), - 
"queue_mode": QueueMode.Batch.Batch, - "media_type": MediaType.Image, - } - project = client.create_project(**data) - assert project.name == data["name"] - assert project.description == data["description"] - - project = client.get_project(project.uid) - assert project.name == data["name"] - assert project.description == data["description"] - - update_data = {"name": rand_gen(str), "description": rand_gen(str)} - project.update(**update_data) - # Test local object updates. - assert project.name == update_data["name"] - assert project.description == update_data["description"] - - # Test remote updates. - project = client.get_project(project.uid) - assert project.name == update_data["name"] - assert project.description == update_data["description"] - - project.delete() - projects = list(client.get_projects()) - assert project not in projects - - -@pytest.fixture -def data_for_project_test(client, rand_gen): - projects = [] - - def _create_project(name: str = None): - if name is None: - name = rand_gen(str) - project = client.create_project(name=name) - projects.append(project) - return project - - yield _create_project - - for project in projects: - project.delete() - - -def test_update_project_resource_tags(client, rand_gen, data_for_project_test): - p1 = data_for_project_test() - - def delete_tag(tag_id: str): - """Deletes a tag given the tag uid. Currently internal use only so this is not public""" - res = client.execute( - """mutation deleteResourceTagPyApi($tag_id: String!) { - deleteResourceTag(input: {id: $tag_id}) { - id - } - } - """, {"tag_id": tag_id}) - return res - - org = client.get_organization() - assert org.uid is not None - - assert p1.uid is not None - - colorA = "#ffffff" - textA = rand_gen(str) - tag = {"text": textA, "color": colorA} - - colorB = colorA - textB = rand_gen(str) - tagB = {"text": textB, "color": colorB} - - tagA = client.get_organization().create_resource_tag(tag) - assert tagA.text == textA - assert '#' + tagA.color == colorA - assert tagA.uid is not None - - tags = org.get_resource_tags() - lenA = len(tags) - assert lenA > 0 - - tagB = client.get_organization().create_resource_tag(tagB) - assert tagB.text == textB - assert '#' + tagB.color == colorB - assert tagB.uid is not None - - tags = client.get_organization().get_resource_tags() - lenB = len(tags) - assert lenB > 0 - assert lenB > lenA - - project_resource_tag = client.get_project( - p1.uid).update_project_resource_tags([str(tagA.uid)]) - assert len(project_resource_tag) == 1 - assert project_resource_tag[0].uid == tagA.uid - - project_resource_tags = client.get_project(p1.uid).get_resource_tags() - assert len(project_resource_tags) == 1 - assert project_resource_tags[0].uid == tagA.uid - - delete_tag(tagA.uid) - delete_tag(tagB.uid) - - -def test_project_filtering(client, rand_gen, data_for_project_test): - name_1 = rand_gen(str) - p1 = data_for_project_test(name_1) - name_2 = rand_gen(str) - p2 = data_for_project_test(name_2) - - assert list(client.get_projects(where=Project.name == name_1)) == [p1] - assert list(client.get_projects(where=Project.name == name_2)) == [p2] - - -def test_extend_reservations(project): - assert project.extend_reservations("LabelingQueue") == 0 - assert project.extend_reservations("ReviewQueue") == 0 - with pytest.raises(InvalidQueryError): - project.extend_reservations("InvalidQueueType") - - -@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", - reason="new mutation does not work for onprem") -def test_attach_instructions(client, project): 
- with pytest.raises(ValueError) as execinfo: - project.upsert_instructions('tests/integration/media/sample_pdf.pdf') - assert str( - execinfo.value - ) == "Cannot attach instructions to a project that has not been set up." - - editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - empty_ontology = {"tools": [], "classifications": []} - project.setup(editor, empty_ontology) - - project.upsert_instructions('tests/integration/media/sample_pdf.pdf') - time.sleep(3) - assert project.ontology().normalized['projectInstructions'] is not None - - with pytest.raises(ValueError) as exc_info: - project.upsert_instructions('/tmp/file.invalid_file_extension') - assert "instructions_file must be a pdf or html file. Found" in str( - exc_info.value) - - -@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", - reason="new mutation does not work for onprem") -def test_html_instructions(project_with_empty_ontology): - html_file_path = '/tmp/instructions.html' - sample_html_str = "" - - with open(html_file_path, 'w') as file: - file.write(sample_html_str) - - project_with_empty_ontology.upsert_instructions(html_file_path) - updated_ontology = project_with_empty_ontology.ontology().normalized - - instructions = updated_ontology.pop('projectInstructions') - assert requests.get(instructions).text == sample_html_str - - -@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", - reason="new mutation does not work for onprem") -def test_same_ontology_after_instructions( - configured_project_with_complex_ontology): - project, _ = configured_project_with_complex_ontology - initial_ontology = project.ontology().normalized - project.upsert_instructions('tests/assets/loremipsum.pdf') - updated_ontology = project.ontology().normalized - - instructions = updated_ontology.pop('projectInstructions') - - assert initial_ontology == updated_ontology - assert instructions is not None - - -def test_batches(project: Project, dataset: Dataset, image_url): - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image" - }, - ] * 2) - task.wait_till_done() - # TODO: Move to export_v2 - data_rows = [dr.uid for dr in list(dataset.export_data_rows())] - batch_one = f'batch one {uuid.uuid4()}' - batch_two = f'batch two {uuid.uuid4()}' - project.create_batch(batch_one, [data_rows[0]]) - project.create_batch(batch_two, [data_rows[1]]) - - names = set([batch.name for batch in list(project.batches())]) - assert names == {batch_one, batch_two} - - -@pytest.mark.parametrize('data_rows', [2], indirect=True) -def test_create_batch_with_global_keys_sync(project: Project, data_rows): - global_keys = [dr.global_key for dr in data_rows] - batch_name = f'batch {uuid.uuid4()}' - batch = project.create_batch(batch_name, global_keys=global_keys) - # TODO: Move to export_v2 - batch_data_rows = set(batch.export_data_rows()) - assert batch_data_rows == set(data_rows) - - -@pytest.mark.parametrize('data_rows', [2], indirect=True) -def test_create_batch_with_global_keys_async(project: Project, data_rows): - global_keys = [dr.global_key for dr in data_rows] - batch_name = f'batch {uuid.uuid4()}' - batch = project._create_batch_async(batch_name, global_keys=global_keys) - # TODO: Move to export_v2 - batch_data_rows = set(batch.export_data_rows()) - assert batch_data_rows == set(data_rows) - - -def test_media_type(client, project: Project, rand_gen): - # Existing project with no media_type - assert isinstance(project.media_type, MediaType) - 
- # Update test - project = client.create_project(name=rand_gen(str)) - project.update(media_type=MediaType.Image) - assert project.media_type == MediaType.Image - project.delete() - - for media_type in MediaType.get_supported_members(): - # Exclude LLM media types for now, as they are not supported - if MediaType[media_type] in [ - MediaType.LLMPromptCreation, MediaType.LLMPromptResponseCreation - ]: - continue - - project = client.create_project(name=rand_gen(str), - media_type=MediaType[media_type]) - assert project.media_type == MediaType[media_type] - project.delete() - - -def test_queue_mode(client, rand_gen): - project = client.create_project(name=rand_gen(str)) # defaults to benchmark - assert project.auto_audit_number_of_labels == 1 - assert project.auto_audit_percentage == 1 - - project = client.create_project(name=rand_gen(str), - quality_mode=QualityMode.Benchmark) - assert project.auto_audit_number_of_labels == 1 - assert project.auto_audit_percentage == 1 - - project = client.create_project(name=rand_gen(str), - quality_mode=QualityMode.Consensus) - assert project.auto_audit_number_of_labels == 3 - assert project.auto_audit_percentage == 0 - - -def test_label_count(client, configured_batch_project_with_label): - project = client.create_project(name="test label count") - assert project.get_label_count() == 0 - project.delete() - - [source_project, _, _, _] = configured_batch_project_with_label - num_labels = sum([1 for _ in source_project.labels()]) - assert source_project.get_label_count() == num_labels - ----- -tests/integration/test_batch.py -import time -from typing import List -from uuid import uuid4 - -import pytest - -from labelbox import Dataset, Project -from labelbox.exceptions import ProcessingWaitTimeout, MalformedQueryException, ResourceConflict, LabelboxError -from integration.conftest import upload_invalid_data_rows_for_dataset, IMAGE_URL, EXTERNAL_ID - - -def get_data_row_ids(ds: Dataset): - return [dr.uid for dr in list(ds.export_data_rows())] - - -def test_create_batch(project: Project, big_dataset_data_row_ids: List[str]): - batch = project.create_batch("test-batch", - big_dataset_data_row_ids, - 3, - consensus_settings={ - 'number_of_labels': 3, - 'coverage_percentage': 0.1 - }) - - assert batch.name == "test-batch" - assert batch.size == len(big_dataset_data_row_ids) - assert len([dr for dr in batch.failed_data_row_ids]) == 0 - - -def test_create_batch_with_invalid_data_rows_ids(project: Project): - with pytest.raises(MalformedQueryException) as ex: - project.create_batch("test-batch", data_rows=['a', 'b', 'c']) - assert str( - ex) == "No valid data rows to be added from the list provided!" 
- - -def test_create_batch_with_the_same_name(project: Project, - small_dataset: Dataset): - batch1 = project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset)) - assert batch1.name == "batch1" - - with pytest.raises(ResourceConflict): - project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset)) - - -def test_create_batch_with_same_data_row_ids(project: Project, - small_dataset: Dataset): - batch1 = project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset)) - assert batch1.name == "batch1" - - with pytest.raises(MalformedQueryException) as ex: - project.create_batch("batch2", - data_rows=get_data_row_ids(small_dataset)) - assert str(ex) == "No valid data rows to add to project" - - -def test_create_batch_with_non_existent_global_keys(project: Project): - with pytest.raises(MalformedQueryException) as ex: - project.create_batch("batch1", global_keys=["key1"]) - assert str( - ex - ) == "Data rows with the following global keys do not exist: key1." - - -def test_create_batch_with_string_priority(project: Project, - small_dataset: Dataset): - with pytest.raises(LabelboxError): - project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset), - priority="abcd") - - -def test_create_batch_with_null_priority(project: Project, - small_dataset: Dataset): - with pytest.raises(LabelboxError): - project.create_batch("batch1", - data_rows=get_data_row_ids(small_dataset), - priority=None) - - -def test_create_batch_async(project: Project, - big_dataset_data_row_ids: List[str]): - batch = project._create_batch_async("big-batch", - big_dataset_data_row_ids, - priority=3) - assert batch.name == "big-batch" - assert batch.size == len(big_dataset_data_row_ids) - assert len([dr for dr in batch.failed_data_row_ids]) == 0 - - -def test_create_batch_with_consensus_settings(project: Project, - small_dataset: Dataset): - data_rows = [dr.uid for dr in list(small_dataset.export_data_rows())] - consensus_settings = {"coverage_percentage": 0.1, "number_of_labels": 3} - batch = project.create_batch("batch with consensus settings", - data_rows, - 3, - consensus_settings=consensus_settings) - assert batch.name == "batch with consensus settings" - assert batch.size == len(data_rows) - assert batch.consensus_settings == consensus_settings - - -def test_create_batch_with_data_row_class(project: Project, - small_dataset: Dataset): - data_rows = list(small_dataset.export_data_rows()) - batch = project.create_batch("test-batch-data-rows", data_rows, 3) - assert batch.name == "test-batch-data-rows" - assert batch.size == len(data_rows) - - -def test_archive_batch(project: Project, small_dataset: Dataset): - data_rows = [dr.uid for dr in list(small_dataset.export_data_rows())] - batch = project.create_batch("batch to archive", data_rows) - batch.remove_queued_data_rows() - exported_data_rows = list(batch.export_data_rows()) - - assert len(exported_data_rows) == 0 - - -def test_delete(project: Project, small_dataset: Dataset): - data_rows = [dr.uid for dr in list(small_dataset.export_data_rows())] - batch = project.create_batch("batch to delete", data_rows) - batch.delete() - - assert len(list(project.batches())) == 0 - - -def test_batch_project(project: Project, small_dataset: Dataset): - data_rows = [dr.uid for dr in list(small_dataset.export_data_rows())] - batch = project.create_batch("batch to test project relationship", - data_rows) - - project_from_batch = batch.project() - - assert project_from_batch.uid == project.uid - assert project_from_batch.name == 
project.name - - -def test_batch_creation_for_data_rows_with_issues( - project: Project, small_dataset: Dataset, - dataset_with_invalid_data_rows: Dataset): - """ - Create a batch containing both valid and invalid data rows - """ - valid_data_rows = [dr.uid for dr in list(small_dataset.data_rows())] - invalid_data_rows = [ - dr.uid for dr in list(dataset_with_invalid_data_rows.data_rows()) - ] - data_rows_to_add = valid_data_rows + invalid_data_rows - - assert len(data_rows_to_add) == 4 - batch = project.create_batch("batch to test failed data rows", - data_rows_to_add) - failed_data_row_ids = [x for x in batch.failed_data_row_ids] - assert len(failed_data_row_ids) == 2 - - failed_data_row_ids_set = set(failed_data_row_ids) - invalid_data_rows_set = set(invalid_data_rows) - assert len(failed_data_row_ids_set.intersection(invalid_data_rows_set)) == 2 - - -def test_batch_creation_with_processing_timeout(project: Project, - small_dataset: Dataset, - unique_dataset: Dataset): - """ - Create a batch with zero wait time, this means that the waiting logic will throw exception immediately - """ - # wait for these data rows to be processed - valid_data_rows = [dr.uid for dr in list(small_dataset.data_rows())] - - # upload data rows for this dataset and don't wait - upload_invalid_data_rows_for_dataset(unique_dataset) - unprocessed_data_rows = [dr.uid for dr in list(unique_dataset.data_rows())] - - data_row_ids = valid_data_rows + unprocessed_data_rows - - stashed_wait_timeout = project._wait_processing_max_seconds - with pytest.raises(ProcessingWaitTimeout): - # emulate the situation where there are still some data rows being - # processed but wait timeout exceeded - project._wait_processing_max_seconds = 0 - project.create_batch("batch to test failed data rows", data_row_ids) - project._wait_processing_max_seconds = stashed_wait_timeout - - -def test_export_data_rows(project: Project, dataset: Dataset): - n_data_rows = 2 - task = dataset.create_data_rows([ - { - "row_data": IMAGE_URL, - "external_id": EXTERNAL_ID - }, - ] * n_data_rows) - task.wait_till_done() - - data_rows = [dr.uid for dr in list(dataset.export_data_rows())] - batch = project.create_batch("batch test", data_rows) - result = list(batch.export_data_rows()) - exported_data_rows = [dr.uid for dr in result] - - assert len(result) == n_data_rows - assert set(data_rows) == set(exported_data_rows) - - -def test_list_all_batches(project: Project, client): - """ - Test to verify that we can retrieve all available batches in the project. 
- """ - # Data to use - img_assets = [{ - "row_data": IMAGE_URL, - "external_id": str(uuid4()) - } for asset in range(0, 2)] - data = [img_assets for _ in range(0, 2)] - - # Setup - batches = [] - datasets = [] - - for assets in data: - dataset = client.create_dataset(name=str(uuid4())) - create_data_rows_task = dataset.create_data_rows(assets) - create_data_rows_task.wait_till_done() - datasets.append(dataset) - - for dataset in datasets: - data_row_ids = get_data_row_ids(dataset) - new_batch = project.create_batch(name=str(uuid4()), - data_rows=data_row_ids) - batches.append(new_batch) - - # Test - project_batches = list(project.batches()) - assert len(batches) == len(project_batches) - - for project_batch in project_batches: - for assets in data: - assert len(assets) == project_batch.size - - # Clean up - for dataset in datasets: - dataset.delete() - - -def test_list_project_batches_with_no_batches(project: Project): - batches = list(project.batches()) - assert len(batches) == 0 - - -@pytest.mark.skip( - reason="Test cannot be used effectively with MAL/LabelImport. \ -Fix/Unskip after resolving deletion with MAL/LabelImport") -def test_delete_labels(project, small_dataset): - data_rows = [dr.uid for dr in list(small_dataset.export_data_rows())] - batch = project.create_batch("batch to delete labels", data_rows) - - -@pytest.mark.skip( - reason="Test cannot be used effectively with MAL/LabelImport. \ -Fix/Unskip after resolving deletion with MAL/LabelImport") -def test_delete_labels_with_templates(project: Project, small_dataset: Dataset): - data_rows = [dr.uid for dr in list(small_dataset.export_data_rows())] - batch = project.create_batch("batch to delete labels w templates", - data_rows) - exported_data_rows = list(batch.export_data_rows()) - res = batch.delete_labels(labels_as_template=True) - exported_data_rows = list(batch.export_data_rows()) - assert len(exported_data_rows) == 5 - ----- -tests/integration/test_ontology.py -import pytest - -from labelbox import OntologyBuilder, MediaType, Tool -from labelbox.orm.model import Entity -import json -import time - -from labelbox.schema.queue_mode import QueueMode - - -def test_feature_schema_is_not_archived(client, ontology): - feature_schema_to_check = ontology.normalized['tools'][0] - result = client.is_feature_schema_archived( - ontology.uid, feature_schema_to_check['featureSchemaId']) - assert result == False - - -def test_feature_schema_is_archived(client, configured_project_with_label): - project, _, _, label = configured_project_with_label - ontology = project.ontology() - feature_schema_id = ontology.normalized['tools'][0]['featureSchemaId'] - result = client.delete_feature_schema_from_ontology(ontology.uid, - feature_schema_id) - assert result.archived == True and result.deleted == False - assert client.is_feature_schema_archived(ontology.uid, - feature_schema_id) == True - - -def test_is_feature_schema_archived_for_non_existing_feature_schema( - client, ontology): - with pytest.raises( - Exception, - match="The specified feature schema was not in the ontology"): - client.is_feature_schema_archived(ontology.uid, - 'invalid-feature-schema-id') - - -def test_is_feature_schema_archived_for_non_existing_ontology(client, ontology): - feature_schema_to_unarchive = ontology.normalized['tools'][0] - with pytest.raises( - Exception, - match="Resource 'Ontology' not found for params: 'invalid-ontology'" - ): - client.is_feature_schema_archived( - 'invalid-ontology', feature_schema_to_unarchive['featureSchemaId']) - - -def 
test_delete_tool_feature_from_ontology(client, ontology): - feature_schema_to_delete = ontology.normalized['tools'][0] - assert len(ontology.normalized['tools']) == 2 - result = client.delete_feature_schema_from_ontology( - ontology.uid, feature_schema_to_delete['featureSchemaId']) - assert result.deleted == True - assert result.archived == False - updatedOntology = client.get_ontology(ontology.uid) - assert len(updatedOntology.normalized['tools']) == 1 - - -@pytest.mark.skip(reason="normalized ontology contains Relationship, " - "which is not finalized yet. introduce this back when" - "Relationship feature is complete and we introduce" - "a Relationship object to the ontology that we can parse") -def test_from_project_ontology(project) -> None: - o = OntologyBuilder.from_project(project) - assert o.asdict() == project.ontology().normalized - - -point = Tool( - tool=Tool.Type.POINT, - name="name", - color="#ff0000", -) - - -def test_deletes_an_ontology(client): - tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] - ontology = client.create_ontology_from_feature_schemas( - name='ontology name', - feature_schema_ids=[feature_schema_id], - media_type=MediaType.Image) - - assert client.delete_unused_ontology(ontology.uid) is None - - client.delete_unused_feature_schema(feature_schema_id) - - -def test_cant_delete_an_ontology_with_project(client): - project = client.create_project(name="test project", - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) - tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] - ontology = client.create_ontology_from_feature_schemas( - name='ontology name', - feature_schema_ids=[feature_schema_id], - media_type=MediaType.Image) - project.setup_editor(ontology) - - with pytest.raises( - Exception, - match= - "Failed to delete the ontology, message: Cannot delete an ontology connected to a project. 
The ontology is connected to projects: " - + project.uid): - client.delete_unused_ontology(ontology.uid) - - project.delete() - client.delete_unused_ontology(ontology.uid) - client.delete_unused_feature_schema(feature_schema_id) - - -def test_inserts_a_feature_schema_at_given_position(client): - tool1 = {'tool': 'polygon', 'name': 'tool1', 'color': 'blue'} - tool2 = {'tool': 'polygon', 'name': 'tool2', 'color': 'blue'} - ontology_normalized_json = {"tools": [tool1, tool2], "classifications": []} - ontology = client.create_ontology(name="ontology", - normalized=ontology_normalized_json, - media_type=MediaType.Image) - created_feature_schema = client.upsert_feature_schema(point.asdict()) - client.insert_feature_schema_into_ontology( - created_feature_schema.normalized['featureSchemaId'], ontology.uid, 1) - ontology = client.get_ontology(ontology.uid) - - assert ontology.normalized['tools'][1][ - 'schemaNodeId'] == created_feature_schema.normalized['schemaNodeId'] - - client.delete_unused_ontology(ontology.uid) - - -def test_moves_already_added_feature_schema_in_ontology(client): - tool1 = {'tool': 'polygon', 'name': 'tool1', 'color': 'blue'} - ontology_normalized_json = {"tools": [tool1], "classifications": []} - ontology = client.create_ontology(name="ontology", - normalized=ontology_normalized_json, - media_type=MediaType.Image) - created_feature_schema = client.upsert_feature_schema(point.asdict()) - feature_schema_id = created_feature_schema.normalized['featureSchemaId'] - client.insert_feature_schema_into_ontology(feature_schema_id, ontology.uid, - 1) - ontology = client.get_ontology(ontology.uid) - assert ontology.normalized['tools'][1][ - 'schemaNodeId'] == created_feature_schema.normalized['schemaNodeId'] - client.insert_feature_schema_into_ontology(feature_schema_id, ontology.uid, - 0) - ontology = client.get_ontology(ontology.uid) - - assert ontology.normalized['tools'][0][ - 'schemaNodeId'] == created_feature_schema.normalized['schemaNodeId'] - - client.delete_unused_ontology(ontology.uid) - - -def test_does_not_include_used_ontologies(client): - tool = client.upsert_feature_schema(point.asdict()) - feature_schema_id = tool.normalized['featureSchemaId'] - ontology_with_project = client.create_ontology_from_feature_schemas( - name='ontology name', - feature_schema_ids=[feature_schema_id], - media_type=MediaType.Image) - project = client.create_project(name="test project", - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) - project.setup_editor(ontology_with_project) - unused_ontologies = client.get_unused_ontologies() - - assert ontology_with_project.uid not in unused_ontologies - - project.delete() - client.delete_unused_ontology(ontology_with_project.uid) - client.delete_unused_feature_schema(feature_schema_id) - - -def _get_attr_stringify_json(obj, attr): - value = getattr(obj, attr.name) - if attr.field_type.name.lower() == "json": - return json.dumps(value, sort_keys=True) - return value - - -def test_feature_schema_create_read(client, rand_gen): - name = f"test-root-schema-{rand_gen(str)}" - feature_schema_cat_normalized = { - 'tool': 'polygon', - 'name': name, - 'color': 'black', - 'classifications': [], - } - created_feature_schema = client.create_feature_schema( - feature_schema_cat_normalized) - queried_feature_schema = client.get_feature_schema( - created_feature_schema.uid) - for attr in Entity.FeatureSchema.fields(): - assert _get_attr_stringify_json(created_feature_schema, - attr) == _get_attr_stringify_json( - queried_feature_schema, attr) - - time.sleep(3) 
# Slight delay for searching - queried_feature_schemas = list(client.get_feature_schemas(name)) - assert [feature_schema.name for feature_schema in queried_feature_schemas - ] == [name] - queried_feature_schema = queried_feature_schemas[0] - - for attr in Entity.FeatureSchema.fields(): - assert _get_attr_stringify_json(created_feature_schema, - attr) == _get_attr_stringify_json( - queried_feature_schema, attr) - - -def test_ontology_create_read(client, rand_gen): - ontology_name = f"test-ontology-{rand_gen(str)}" - tool_name = f"test-ontology-tool-{rand_gen(str)}" - feature_schema_cat_normalized = { - 'tool': 'polygon', - 'name': tool_name, - 'color': 'black', - 'classifications': [], - } - feature_schema = client.create_feature_schema(feature_schema_cat_normalized) - created_ontology = client.create_ontology_from_feature_schemas( - name=ontology_name, - feature_schema_ids=[feature_schema.uid], - media_type=MediaType.Image) - tool_normalized = created_ontology.normalized['tools'][0] - for k, v in feature_schema_cat_normalized.items(): - assert tool_normalized[k] == v - assert tool_normalized['schemaNodeId'] is not None - assert tool_normalized['featureSchemaId'] == feature_schema.uid - - queried_ontology = client.get_ontology(created_ontology.uid) - - for attr in Entity.Ontology.fields(): - assert _get_attr_stringify_json(created_ontology, - attr) == _get_attr_stringify_json( - queried_ontology, attr) - - time.sleep(3) # Slight delay for searching - queried_ontologies = list(client.get_ontologies(ontology_name)) - assert [ontology.name for ontology in queried_ontologies] == [ontology_name] - queried_ontology = queried_ontologies[0] - for attr in Entity.Ontology.fields(): - assert _get_attr_stringify_json(created_ontology, - attr) == _get_attr_stringify_json( - queried_ontology, attr) - - -def test_unarchive_feature_schema_node(client, ontology): - feature_schema_to_unarchive = ontology.normalized['tools'][0] - result = client.unarchive_feature_schema_node( - ontology.uid, feature_schema_to_unarchive['featureSchemaId']) - assert result == None - - -def test_unarchive_feature_schema_node_for_non_existing_feature_schema( - client, ontology): - with pytest.raises( - Exception, - match= - "Failed to find feature schema node by id: invalid-feature-schema-id" - ): - client.unarchive_feature_schema_node(ontology.uid, - 'invalid-feature-schema-id') - - -def test_unarchive_feature_schema_node_for_non_existing_ontology( - client, ontology): - feature_schema_to_unarchive = ontology.normalized['tools'][0] - with pytest.raises(Exception, - match="Failed to find ontology by id: invalid-ontology"): - client.unarchive_feature_schema_node( - 'invalid-ontology', feature_schema_to_unarchive['featureSchemaId']) - ----- -tests/integration/test_task_queue.py -import time - -from labelbox import Project -from labelbox.schema.identifiables import GlobalKeys, UniqueIds - - -def test_get_task_queue(project: Project): - task_queues = project.task_queues() - assert len(task_queues) == 3 - review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") - assert review_queue - - -def _validate_moved(project, queue_name, data_row_count): - timeout_seconds = 30 - sleep_time = 2 - while True: - task_queues = project.task_queues() - review_queue = next( - tq for tq in task_queues if tq.queue_type == queue_name) - - if review_queue.data_row_count == data_row_count: - break - - if timeout_seconds <= 0: - raise AssertionError( - "Timed out expecting data_row_count of 1 in the review queue") - - 
timeout_seconds -= sleep_time - time.sleep(sleep_time) - - -def test_move_to_task(configured_batch_project_with_label): - project, _, data_row, _ = configured_batch_project_with_label - task_queues = project.task_queues() - - review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") - project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid) - _validate_moved(project, "MANUAL_REVIEW_QUEUE", 1) - - review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REWORK_QUEUE") - project.move_data_rows_to_task_queue(GlobalKeys([data_row.global_key]), - review_queue.uid) - _validate_moved(project, "MANUAL_REWORK_QUEUE", 1) - - review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") - project.move_data_rows_to_task_queue(UniqueIds([data_row.uid]), - review_queue.uid) - _validate_moved(project, "MANUAL_REVIEW_QUEUE", 1) - ----- -tests/integration/test_filtering.py -import pytest - -from labelbox import Project -from labelbox.exceptions import InvalidQueryError -from labelbox.schema.queue_mode import QueueMode - - -@pytest.fixture -def project_to_test_where(client, rand_gen): - p_a_name = f"a-{rand_gen(str)}" - p_b_name = f"b-{rand_gen(str)}" - p_c_name = f"c-{rand_gen(str)}" - - p_a = client.create_project(name=p_a_name, queue_mode=QueueMode.Batch) - p_b = client.create_project(name=p_b_name, queue_mode=QueueMode.Batch) - p_c = client.create_project(name=p_c_name, queue_mode=QueueMode.Batch) - - yield p_a, p_b, p_c - - p_a.delete() - p_b.delete() - p_c.delete() - - -# Avoid assertions using equality to prevent intermittent failures due to -# other builds simultaneously adding projects to test org -def test_where(client, project_to_test_where): - p_a, p_b, p_c = project_to_test_where - p_a_name, p_b_name = [p.name for p in [p_a, p_b]] - - def get(where=None): - date_where = Project.created_at >= p_a.created_at - where = date_where if where is None else where & date_where - return {p.uid for p in client.get_projects(where)} - - assert {p_a.uid, p_b.uid, p_c.uid}.issubset(get()) - e_a = get(Project.name == p_a_name) - assert p_a.uid in e_a and p_b.uid not in e_a and p_c.uid not in e_a - not_b = get(Project.name != p_b_name) - assert {p_a.uid, p_c.uid}.issubset(not_b) and p_b.uid not in not_b - gt_b = get(Project.name > p_b_name) - assert p_c.uid in gt_b and p_a.uid not in gt_b and p_b.uid not in gt_b - lt_b = get(Project.name < p_b_name) - assert p_a.uid in lt_b and p_b.uid not in lt_b and p_c.uid not in lt_b - ge_b = get(Project.name >= p_b_name) - assert {p_b.uid, p_c.uid}.issubset(ge_b) and p_a.uid not in ge_b - le_b = get(Project.name <= p_b_name) - assert {p_a.uid, p_b.uid}.issubset(le_b) and p_c.uid not in le_b - - -def test_unsupported_where(client): - with pytest.raises(InvalidQueryError): - client.get_projects(where=(Project.name == "a") & (Project.name == "b")) - - # TODO support logical OR and NOT in where - with pytest.raises(InvalidQueryError): - client.get_projects(where=(Project.name == "a") | - (Project.description == "b")) - - with pytest.raises(InvalidQueryError): - client.get_projects(where=~(Project.name == "a")) ----- -tests/integration/test_batches.py -from typing import List - -import pytest - -from labelbox import Project, Dataset - - -def test_create_batches(project: Project, big_dataset_data_row_ids: List[str]): - task = project.create_batches("test-batch", - big_dataset_data_row_ids, - priority=3) - - task.wait_till_done() - assert task.errors() is None - batches = task.result() - 
- assert len(batches) == 1 - assert batches[0].name == "test-batch0000" - assert batches[0].size == len(big_dataset_data_row_ids) - - -def test_create_batches_from_dataset(project: Project, big_dataset: Dataset): - data_rows = [dr.uid for dr in list(big_dataset.export_data_rows())] - project._wait_until_data_rows_are_processed(data_rows, [], 300) - - task = project.create_batches_from_dataset("test-batch", - big_dataset.uid, - priority=3) - - task.wait_till_done() - assert task.errors() is None - batches = task.result() - - assert len(batches) == 1 - assert batches[0].name == "test-batch0000" - assert batches[0].size == len(data_rows) - ----- -tests/integration/annotation_import/test_model.py -import pytest - -from labelbox import Model -from labelbox.exceptions import ResourceNotFoundError - - -def test_model(client, configured_project_with_one_data_row, rand_gen): - # Get all - models = list(client.get_models()) - for m in models: - assert isinstance(m, Model) - - # Create - ontology = configured_project_with_one_data_row.ontology() - data = {"name": rand_gen(str), "ontology_id": ontology.uid} - model = client.create_model(data["name"], data["ontology_id"]) - assert model.name == data["name"] - - # Get one - model = client.get_model(model.uid) - assert model.name == data["name"] - - # Delete - model.delete() - with pytest.raises(ResourceNotFoundError): - client.get_model(model.uid) - ----- -tests/integration/annotation_import/conftest.py -import uuid - -import pytest -import time -import requests - -from labelbox import parser, MediaType - -from typing import Type -from labelbox.schema.labeling_frontend import LabelingFrontend -from labelbox.schema.annotation_import import LabelImport, AnnotationImportState -from labelbox.schema.project import Project -from labelbox.schema.queue_mode import QueueMode - -DATA_ROW_PROCESSING_WAIT_TIMEOUT_SECONDS = 40 -DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS = 7 - - -@pytest.fixture() -def audio_data_row(rand_gen): - return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/audio-sample-data/sample-audio-1.mp3", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/audio-sample-data/sample-audio-1.mp3-{rand_gen(str)}", - "media_type": - "AUDIO", - } - - -@pytest.fixture() -def conversation_data_row(rand_gen): - return { - "row_data": - "https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json", - "global_key": - f"https://storage.googleapis.com/labelbox-developer-testing-assets/conversational_text/1000-conversations/conversation-1.json-{rand_gen(str)}", - } - - -@pytest.fixture() -def dicom_data_row(rand_gen): - return { - "row_data": - "https://storage.googleapis.com/labelbox-datasets/dicom-sample-data/sample-dicom-1.dcm", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/dicom-sample-data/sample-dicom-1.dcm-{rand_gen(str)}", - "media_type": - "DICOM", - } - - -@pytest.fixture() -def geospatial_data_row(rand_gen): - return { - "row_data": { - "tile_layer_url": - "https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/{z}/{x}/{y}.png", - "bounds": [[19.405662413477728, -99.21052827588443], - [19.400498983095076, -99.20534818927473]], - "min_zoom": - 12, - "max_zoom": - 20, - "epsg": - "EPSG4326", - }, - "global_key": - f"https://s3-us-west-1.amazonaws.com/lb-tiler-layers/mexico_city/z/x/y.png-{rand_gen(str)}", - "media_type": - "TMS_GEO", - } - - -@pytest.fixture() -def html_data_row(rand_gen): - return { - "row_data": - 
"https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html", - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/html_sample_data/sample_html_1.html-{rand_gen(str)}", - } - - -@pytest.fixture() -def image_data_row(rand_gen): - return { - "row_data": - "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg", - "global_key": - f"https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-{rand_gen(str)}", - "media_type": - "IMAGE", - } - - -@pytest.fixture() -def document_data_row(rand_gen): - return { - "row_data": { - "pdf_url": - "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf", - "text_layer_url": - "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483-lb-textlayer.json" - }, - "global_key": - f"https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf-{rand_gen(str)}", - "media_type": - "PDF", - } - - -@pytest.fixture() -def text_data_row(rand_gen): - return { - "row_data": - "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-1.txt", - "global_key": - f"https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample-text-1.txt-{rand_gen(str)}", - "media_type": - "TEXT", - } - - -@pytest.fixture() -def llm_prompt_creation_data_row(rand_gen): - return { - "row_data": { - "type": "application/llm.prompt-creation", - "version": 1 - }, - "global_key": rand_gen(str) - } - - -@pytest.fixture() -def llm_prompt_response_data_row(rand_gen): - return { - "row_data": { - "type": "application/llm.prompt-response-creation", - "version": 1 - }, - "global_key": rand_gen(str) - } - - -@pytest.fixture -def data_row_json_by_data_type(audio_data_row, conversation_data_row, - dicom_data_row, geospatial_data_row, - html_data_row, image_data_row, document_data_row, - text_data_row, video_data_row, - llm_prompt_creation_data_row, - llm_prompt_response_data_row): - return { - 'audio': audio_data_row, - 'conversation': conversation_data_row, - 'dicom': dicom_data_row, - 'geospatial': geospatial_data_row, - 'html': html_data_row, - 'image': image_data_row, - 'document': document_data_row, - 'text': text_data_row, - 'video': video_data_row, - 'llmpromptcreation': llm_prompt_creation_data_row, - 'llmpromptresponsecreation': llm_prompt_response_data_row, - 'llmresponsecreation': text_data_row - } - - -@pytest.fixture -def exports_v2_by_data_type(expected_export_v2_image, expected_export_v2_audio, - expected_export_v2_html, expected_export_v2_text, - expected_export_v2_video, - expected_export_v2_conversation, - expected_export_v2_dicom, - expected_export_v2_document, - expected_export_v2_llm_prompt_creation, - expected_export_v2_llm_prompt_response_creation, - expected_export_v2_llm_response_creation): - return { - 'image': - expected_export_v2_image, - 'audio': - expected_export_v2_audio, - 'html': - expected_export_v2_html, - 'text': - expected_export_v2_text, - 'video': - expected_export_v2_video, - 'conversation': - expected_export_v2_conversation, - 'dicom': - expected_export_v2_dicom, - 'document': - expected_export_v2_document, - 'llmpromptcreation': - expected_export_v2_llm_prompt_creation, - 'llmpromptresponsecreation': - expected_export_v2_llm_prompt_response_creation, - 'llmresponsecreation': - expected_export_v2_llm_response_creation - } - - -@pytest.fixture -def annotations_by_data_type(polygon_inference, 
rectangle_inference, - rectangle_inference_document, line_inference, - entity_inference, entity_inference_document, - checklist_inference, text_inference, - video_checklist_inference): - return { - 'audio': [checklist_inference, text_inference], - 'conversation': [checklist_inference, text_inference, entity_inference], - 'dicom': [line_inference], - 'document': [ - entity_inference_document, checklist_inference, text_inference, - rectangle_inference_document - ], - 'html': [text_inference, checklist_inference], - 'image': [ - polygon_inference, rectangle_inference, line_inference, - checklist_inference, text_inference - ], - 'text': [entity_inference, checklist_inference, text_inference], - 'video': [video_checklist_inference], - 'llmpromptcreation': [checklist_inference, text_inference], - 'llmpromptresponsecreation': [checklist_inference, text_inference], - 'llmresponsecreation': [checklist_inference, text_inference] - } - - -@pytest.fixture -def annotations_by_data_type_v2( - polygon_inference, rectangle_inference, rectangle_inference_document, - line_inference_v2, line_inference, entity_inference, - entity_inference_index, entity_inference_document, - checklist_inference_index, text_inference_index, checklist_inference, - text_inference, video_checklist_inference): - return { - 'audio': [checklist_inference, text_inference], - 'conversation': [ - checklist_inference_index, text_inference_index, - entity_inference_index - ], - 'dicom': [line_inference_v2], - 'document': [ - entity_inference_document, checklist_inference, text_inference, - rectangle_inference_document - ], - 'html': [text_inference, checklist_inference], - 'image': [ - polygon_inference, rectangle_inference, line_inference, - checklist_inference, text_inference - ], - 'text': [entity_inference, checklist_inference, text_inference], - 'video': [video_checklist_inference], - 'llmpromptcreation': [checklist_inference, text_inference], - 'llmpromptresponsecreation': [checklist_inference, text_inference], - 'llmresponsecreation': [checklist_inference, text_inference] - } - - -@pytest.fixture(scope='session') -def ontology(): - bbox_tool_with_nested_text = { - 'required': - False, - 'name': - 'bbox_tool_with_nested_text', - 'tool': - 'rectangle', - 'color': - '#a23030', - 'classifications': [{ - 'required': - False, - 'instructions': - 'nested', - 'name': - 'nested', - 'type': - 'radio', - 'options': [{ - 'label': - 'radio_option_1', - 'value': - 'radio_value_1', - 'options': [{ - 'required': - False, - 'instructions': - 'nested_checkbox', - 'name': - 'nested_checkbox', - 'type': - 'checklist', - 'options': [{ - 'label': 'nested_checkbox_option_1', - 'value': 'nested_checkbox_value_1', - 'options': [] - }, { - 'label': 'nested_checkbox_option_2', - 'value': 'nested_checkbox_value_2' - }] - }, { - 'required': False, - 'instructions': 'nested_text', - 'name': 'nested_text', - 'type': 'text', - 'options': [] - }] - },] - }] - } - - bbox_tool = { - 'required': - False, - 'name': - 'bbox', - 'tool': - 'rectangle', - 'color': - '#a23030', - 'classifications': [{ - 'required': - False, - 'instructions': - 'nested', - 'name': - 'nested', - 'type': - 'radio', - 'options': [{ - 'label': - 'radio_option_1', - 'value': - 'radio_value_1', - 'options': [{ - 'required': - False, - 'instructions': - 'nested_checkbox', - 'name': - 'nested_checkbox', - 'type': - 'checklist', - 'options': [{ - 'label': 'nested_checkbox_option_1', - 'value': 'nested_checkbox_value_1', - 'options': [] - }, { - 'label': 'nested_checkbox_option_2', - 'value': 
'nested_checkbox_value_2' - }] - }] - },] - }] - } - - polygon_tool = { - 'required': False, - 'name': 'polygon', - 'tool': 'polygon', - 'color': '#FF34FF', - 'classifications': [] - } - polyline_tool = { - 'required': False, - 'name': 'polyline', - 'tool': 'line', - 'color': '#FF4A46', - 'classifications': [] - } - point_tool = { - 'required': False, - 'name': 'point--', - 'tool': 'point', - 'color': '#008941', - 'classifications': [] - } - entity_tool = { - 'required': False, - 'name': 'entity--', - 'tool': 'named-entity', - 'color': '#006FA6', - 'classifications': [] - } - segmentation_tool = { - 'required': False, - 'name': 'segmentation--', - 'tool': 'superpixel', - 'color': '#A30059', - 'classifications': [] - } - raster_segmentation_tool = { - 'required': False, - 'name': 'segmentation_mask', - 'tool': 'raster-segmentation', - 'color': '#ff0000', - 'classifications': [] - } - checklist = { - 'required': - False, - 'instructions': - 'checklist', - 'name': - 'checklist', - 'type': - 'checklist', - 'options': [{ - 'label': 'option1', - 'value': 'option1' - }, { - 'label': 'option2', - 'value': 'option2' - }, { - 'label': 'optionN', - 'value': 'optionn' - }] - } - checklist_index = { - 'required': - False, - 'instructions': - 'checklist_index', - 'name': - 'checklist_index', - 'type': - 'checklist', - 'scope': - 'index', - 'options': [{ - 'label': 'option1_index', - 'value': 'option1_index' - }, { - 'label': 'option2_index', - 'value': 'option2_index' - }, { - 'label': 'optionN_index', - 'value': 'optionn_index' - }] - } - free_form_text = { - 'required': False, - 'instructions': 'text', - 'name': 'text', - 'type': 'text', - 'options': [] - } - free_form_text_index = { - 'required': False, - 'instructions': 'text_index', - 'name': 'text_index', - 'type': 'text', - 'scope': 'index', - 'options': [] - } - radio = { - 'required': - False, - 'instructions': - 'radio', - 'name': - 'radio', - 'type': - 'radio', - 'options': [{ - 'label': 'first_radio_answer', - 'value': 'first_radio_answer', - 'options': [] - }, { - 'label': 'second_radio_answer', - 'value': 'second_radio_answer', - 'options': [] - }] - } - named_entity = { - 'tool': 'named-entity', - 'name': 'named-entity', - 'required': False, - 'color': '#A30059', - 'classifications': [], - } - - tools = [ - bbox_tool, - bbox_tool_with_nested_text, - polygon_tool, - polyline_tool, - point_tool, - entity_tool, - segmentation_tool, - raster_segmentation_tool, - named_entity, - ] - classifications = [ - checklist, checklist_index, free_form_text, free_form_text_index, radio - ] - return {"tools": tools, "classifications": classifications} - - -@pytest.fixture -def wait_for_label_processing(): - """ - Do not use. Only for testing. - - Returns project's labels as a list after waiting for them to finish processing. 
- If `project.labels()` is called before label is fully processed, - it may return an empty set - """ - - def func(project): - timeout_seconds = 10 - while True: - labels = list(project.labels()) - if len(labels) > 0: - return labels - timeout_seconds -= 2 - if timeout_seconds <= 0: - raise TimeoutError( - f"Timed out waiting for label for project '{project.uid}' to finish processing" - ) - time.sleep(2) - - return func - - -@pytest.fixture -def configured_project_datarow_id(configured_project): - - def get_data_row_id(indx=0): - return configured_project.data_row_ids[indx] - - yield get_data_row_id - - -@pytest.fixture -def configured_project_one_datarow_id(configured_project_with_one_data_row): - - def get_data_row_id(indx=0): - return configured_project_with_one_data_row.data_row_ids[0] - - yield get_data_row_id - - -@pytest.fixture -def configured_project(client, initial_dataset, ontology, rand_gen, image_url): - dataset = initial_dataset - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch) - editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - project.setup(editor, ontology) - - data_row_ids = [] - - ontologies = ontology['tools'] + ontology['classifications'] - for ind in range(len(ontologies)): - data_row_ids.append( - dataset.create_data_row( - row_data=image_url, - global_key=f"gk_{ontologies[ind]['name']}_{rand_gen(str)}").uid) - project._wait_until_data_rows_are_processed(data_row_ids=data_row_ids, - sleep_interval=3) - - project.create_batch( - rand_gen(str), - data_row_ids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - project.data_row_ids = data_row_ids - - yield project - - project.delete() - - -@pytest.fixture -def project_with_ontology(client, configured_project, ontology, rand_gen): - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Image) - editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - project.setup(editor, ontology) - - yield project, ontology - - project.delete() - - -@pytest.fixture -def configured_project_pdf(client, ontology, rand_gen, pdf_url): - project = client.create_project(name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Pdf) - dataset = client.create_dataset(name=rand_gen(str)) - editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - project.setup(editor, ontology) - data_row = dataset.create_data_row(pdf_url) - data_row_ids = [data_row.uid] - project.create_batch( - rand_gen(str), - data_row_ids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - project.data_row_ids = data_row_ids - yield project - project.delete() - dataset.delete() - - -@pytest.fixture -def dataset_pdf_entity(client, rand_gen, document_data_row): - dataset = client.create_dataset(name=rand_gen(str)) - data_row_ids = [] - data_row = dataset.create_data_row(document_data_row) - data_row_ids.append(data_row.uid) - yield dataset, data_row_ids - dataset.delete() - - -@pytest.fixture -def dataset_conversation_entity(client, rand_gen, conversation_entity_data_row, - wait_for_data_row_processing): - dataset = client.create_dataset(name=rand_gen(str)) - data_row_ids = [] - data_row = dataset.create_data_row(conversation_entity_data_row) - data_row = wait_for_data_row_processing(client, data_row) - - data_row_ids.append(data_row.uid) - yield dataset, data_row_ids - dataset.delete() - - 
-@pytest.fixture
-def configured_project_with_one_data_row(client, ontology, rand_gen,
-                                         initial_dataset, image_url):
-    project = client.create_project(name=rand_gen(str),
-                                    description=rand_gen(str),
-                                    queue_mode=QueueMode.Batch)
-    editor = list(
-        client.get_labeling_frontends(
-            where=LabelingFrontend.name == "editor"))[0]
-    project.setup(editor, ontology)
-
-    data_row = initial_dataset.create_data_row(row_data=image_url)
-    data_row_ids = [data_row.uid]
-    project._wait_until_data_rows_are_processed(data_row_ids=data_row_ids,
-                                                sleep_interval=3)
-
-    batch = project.create_batch(
-        rand_gen(str),
-        data_row_ids,  # sample of data row objects
-        5  # priority between 1(Highest) - 5(lowest)
-    )
-    project.data_row_ids = data_row_ids
-
-    yield project
-
-    batch.delete()
-    project.delete()
-
-
-# This function converts an ontology feature into an actual annotation.
-# At the moment it expects only one feature per tool type, which creates unnecessary coupling between different tests.
-# For 'rectangle' we have extended it to support multiple instances of the same tool type.
-# TODO: we will support this approach in the future for all tools
-#
-"""
-Please note that this fixture now offers the flexibility to configure three different strategies for generating data row ids for predictions:
-
-Default (configured_project fixture):
-    configured_project generates a data row for each member of the ontology.
-    This makes sure each prediction has its own data row id. This is applicable to prediction upload cases where the last label overwrites existing ones.
-
-Optimized Strategy (configured_project_with_one_data_row fixture):
-    This fixture has only one data row, and all predictions will be mapped to it.
-
-Custom Data Row IDs Strategy:
-    Users can supply hard-coded data row ids when creating a data row is not required.
-    This particular fixture, termed "hardcoded_datarow_id", should be defined locally within a test file.
-    In the future, we can use this approach to inject the correct number of rows instead of using the configured_project fixture,
-    which creates a data row for each member of the ontology (14 in total) on each run.
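-
-A minimal sketch of such a locally-defined fixture follows. It mirrors the
-factory shape of the other data-row-id fixtures above; the id value is a
-placeholder for illustration, not a real data row id:
-
-    @pytest.fixture
-    def hardcoded_datarow_id():
-        data_row_id = 'placeholder-data-row-id'  # placeholder, not a real id
-
-        def get_data_row_id(indx=0):
-            return data_row_id
-
-        yield get_data_row_id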
-""" - - -@pytest.fixture -def prediction_id_mapping(ontology, request): - # Maps tool types to feature schema ids - if 'configured_project' in request.fixturenames: - data_row_id_factory = request.getfixturevalue( - 'configured_project_datarow_id') - project = request.getfixturevalue('configured_project') - elif 'hardcoded_datarow_id' in request.fixturenames: - data_row_id_factory = request.getfixturevalue('hardcoded_datarow_id') - project = request.getfixturevalue('configured_project_with_ontology') - else: - data_row_id_factory = request.getfixturevalue( - 'configured_project_one_datarow_id') - project = request.getfixturevalue( - 'configured_project_with_one_data_row') - - ontology = project.ontology().normalized - - result = {} - - for idx, tool in enumerate(ontology['tools'] + ontology['classifications']): - if 'tool' in tool: - tool_type = tool['tool'] - else: - tool_type = tool[ - 'type'] if 'scope' not in tool else f"{tool['type']}_{tool['scope']}" # so 'checklist' of 'checklist_index' - - # TODO: remove this once we have a better way to associate multiple tools instances with a single tool type - if tool_type == 'rectangle': - value = { - "uuid": str(uuid.uuid4()), - "schemaId": tool['featureSchemaId'], - "name": tool['name'], - "dataRow": { - "id": data_row_id_factory(idx), - }, - 'tool': tool - } - if tool_type not in result: - result[tool_type] = [] - result[tool_type].append(value) - else: - result[tool_type] = { - "uuid": str(uuid.uuid4()), - "schemaId": tool['featureSchemaId'], - "name": tool['name'], - "dataRow": { - "id": data_row_id_factory(idx), - }, - 'tool': tool - } - return result - - -@pytest.fixture -def polygon_inference(prediction_id_mapping): - polygon = prediction_id_mapping['polygon'].copy() - polygon.update({ - "polygon": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 142.769, - "y": 104.923 - }, { - "x": 57.846, - "y": 118.769 - }, { - "x": 28.308, - "y": 169.846 - }] - }) - del polygon['tool'] - return polygon - - -def find_tool_by_name(tool_instances, name): - for tool in tool_instances: - if tool['name'] == name: - return tool - return None - - -@pytest.fixture -def rectangle_inference(prediction_id_mapping): - tool_instance = find_tool_by_name(prediction_id_mapping['rectangle'], - 'bbox') - rectangle = tool_instance.copy() - rectangle.update({ - "bbox": { - "top": 48, - "left": 58, - "height": 65, - "width": 12 - }, - 'classifications': [{ - "schemaId": - rectangle['tool']['classifications'][0]['featureSchemaId'], - "name": - rectangle['tool']['classifications'][0]['name'], - "answer": { - "schemaId": - rectangle['tool']['classifications'][0]['options'][0] - ['featureSchemaId'], - "name": - rectangle['tool']['classifications'][0]['options'][0] - ['value'], - "customMetrics": [{ - "name": "customMetric1", - "value": 0.4 - }], - } - }] - }) - del rectangle['tool'] - return rectangle - - -@pytest.fixture -def rectangle_inference_with_confidence(prediction_id_mapping): - tool_instance = find_tool_by_name(prediction_id_mapping['rectangle'], - 'bbox_tool_with_nested_text') - rectangle = tool_instance.copy() - rectangle.update({ - "bbox": { - "top": 48, - "left": 58, - "height": 65, - "width": 12 - }, - 'classifications': [{ - "schemaId": - rectangle['tool']['classifications'][0]['featureSchemaId'], - "name": - rectangle['tool']['classifications'][0]['name'], - "answer": { - "schemaId": - rectangle['tool']['classifications'][0]['options'][0] - ['featureSchemaId'], - "name": - rectangle['tool']['classifications'][0]['options'][0] - ['value'], - 
"classifications": [{ - "schemaId": - rectangle['tool']['classifications'][0]['options'][0] - ['options'][1]['featureSchemaId'], - "name": - rectangle['tool']['classifications'][0]['options'][0] - ['options'][1]['name'], - "answer": - 'nested answer' - }], - } - }] - }) - - rectangle.update({"confidence": 0.9}) - rectangle["classifications"][0]["answer"]["confidence"] = 0.8 - rectangle["classifications"][0]["answer"]["classifications"][0][ - "confidence"] = 0.7 - - del rectangle['tool'] - return rectangle - - -@pytest.fixture -def rectangle_inference_document(rectangle_inference): - rectangle = rectangle_inference.copy() - rectangle.update({"page": 1, "unit": "POINTS"}) - return rectangle - - -@pytest.fixture -def line_inference(prediction_id_mapping): - line = prediction_id_mapping['line'].copy() - line.update( - {"line": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 150.692, - "y": 160.154 - }]}) - del line['tool'] - return line - - -@pytest.fixture -def line_inference_v2(prediction_id_mapping): - line = prediction_id_mapping['line'].copy() - line_data = { - "groupKey": - "axial", - "segments": [{ - "keyframes": [{ - "frame": - 1, - "line": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 150.692, - "y": 160.154 - }] - }] - },] - } - line.update(line_data) - del line['tool'] - return line - - -@pytest.fixture -def point_inference(prediction_id_mapping): - point = prediction_id_mapping['point'].copy() - point.update({"point": {"x": 147.692, "y": 118.154}}) - del point['tool'] - return point - - -@pytest.fixture -def entity_inference(prediction_id_mapping): - entity = prediction_id_mapping['named-entity'].copy() - entity.update({"location": {"start": 67, "end": 128}}) - del entity['tool'] - return entity - - -@pytest.fixture -def entity_inference_index(prediction_id_mapping): - entity = prediction_id_mapping['named-entity'].copy() - entity.update({ - "location": { - "start": 0, - "end": 8 - }, - "messageId": "0", - }) - - del entity['tool'] - return entity - - -@pytest.fixture -def entity_inference_document(prediction_id_mapping): - entity = prediction_id_mapping['named-entity'].copy() - document_selections = { - "textSelections": [{ - "tokenIds": [ - "3f984bf3-1d61-44f5-b59a-9658a2e3440f", - "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", - "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", - "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", - "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", - "67c7c19e-4654-425d-bf17-2adb8cf02c30", - "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", - "b0e94071-2187-461e-8e76-96c58738a52c" - ], - "groupId": "2f4336f4-a07e-4e0a-a9e1-5629b03b719b", - "page": 1, - }] - } - entity.update(document_selections) - del entity['tool'] - return entity - - -@pytest.fixture -def segmentation_inference(prediction_id_mapping): - segmentation = prediction_id_mapping['superpixel'].copy() - segmentation.update({ - 'mask': { - "instanceURI": - "https://storage.googleapis.com/labelbox-datasets/image_sample_data/raster_seg.png", - "colorRGB": (255, 255, 255) - } - }) - del segmentation['tool'] - return segmentation - - -@pytest.fixture -def segmentation_inference_rle(prediction_id_mapping): - segmentation = prediction_id_mapping['superpixel'].copy() - segmentation.update({ - 'uuid': str(uuid.uuid4()), - 'mask': { - 'size': [10, 10], - 'counts': [1, 0, 10, 100] - } - }) - del segmentation['tool'] - return segmentation - - -@pytest.fixture -def segmentation_inference_png(prediction_id_mapping): - segmentation = prediction_id_mapping['superpixel'].copy() - segmentation.update({ - 'uuid': str(uuid.uuid4()), - 'mask': { - 
'png': "somedata", - } - }) - del segmentation['tool'] - return segmentation - - -@pytest.fixture -def checklist_inference(prediction_id_mapping): - checklist = prediction_id_mapping['checklist'].copy() - checklist.update({ - 'answers': [{ - 'schemaId': checklist['tool']['options'][0]['featureSchemaId'] - }] - }) - del checklist['tool'] - return checklist - - -@pytest.fixture -def checklist_inference_index(prediction_id_mapping): - checklist = prediction_id_mapping['checklist_index'].copy() - checklist.update({ - 'answers': [{ - 'schemaId': checklist['tool']['options'][0]['featureSchemaId'] - }], - "messageId": "0", - }) - del checklist['tool'] - return checklist - - -@pytest.fixture -def text_inference(prediction_id_mapping): - text = prediction_id_mapping['text'].copy() - text.update({'answer': "free form text..."}) - del text['tool'] - return text - - -@pytest.fixture -def text_inference_with_confidence(text_inference): - text = text_inference.copy() - text.update({'confidence': 0.9}) - return text - - -@pytest.fixture -def text_inference_index(prediction_id_mapping): - text = prediction_id_mapping['text_index'].copy() - text.update({'answer': "free form text...", "messageId": "0"}) - del text['tool'] - return text - - -@pytest.fixture -def video_checklist_inference(prediction_id_mapping): - checklist = prediction_id_mapping['checklist'].copy() - checklist.update({ - 'answers': [{ - 'schemaId': checklist['tool']['options'][0]['featureSchemaId'] - }] - }) - - checklist.update( - {"frames": [{ - "start": 7, - "end": 13, - }, { - "start": 18, - "end": 19, - }]}) - del checklist['tool'] - return checklist - - -@pytest.fixture -def model_run_predictions(polygon_inference, rectangle_inference, - line_inference): - # Not supporting mask since there isn't a signed url representing a seg mask to upload - return [polygon_inference, rectangle_inference, line_inference] - - -@pytest.fixture -def object_predictions(polygon_inference, rectangle_inference, line_inference, - entity_inference, segmentation_inference): - return [ - polygon_inference, rectangle_inference, line_inference, - entity_inference, segmentation_inference - ] - - -@pytest.fixture -def object_predictions_for_annotation_import(polygon_inference, - rectangle_inference, - line_inference, - segmentation_inference): - return [ - polygon_inference, rectangle_inference, line_inference, - segmentation_inference - ] - - -@pytest.fixture -def classification_predictions(checklist_inference, text_inference): - return [checklist_inference, text_inference] - - -@pytest.fixture -def predictions(object_predictions, classification_predictions): - return object_predictions + classification_predictions - - -@pytest.fixture -def predictions_with_confidence(text_inference_with_confidence, - rectangle_inference_with_confidence): - return [text_inference_with_confidence, rectangle_inference_with_confidence] - - -@pytest.fixture -def model(client, rand_gen, configured_project): - ontology = configured_project.ontology() - data = {"name": rand_gen(str), "ontology_id": ontology.uid} - model = client.create_model(data["name"], data["ontology_id"]) - yield model - try: - model.delete() - except: - # Already was deleted by the test - pass - - -@pytest.fixture -def model_run(rand_gen, model): - name = rand_gen(str) - model_run = model.create_model_run(name) - yield model_run - try: - model_run.delete() - except: - # Already was deleted by the test - pass - - -@pytest.fixture -def model_run_with_training_metadata(rand_gen, model): - name = rand_gen(str) - 
training_metadata = {"batch_size": 1000} - model_run = model.create_model_run(name, training_metadata) - yield model_run - try: - model_run.delete() - except: - # Already was deleted by the test - pass - - -@pytest.fixture -def model_run_with_data_rows(client, configured_project, model_run_predictions, - model_run, wait_for_label_processing): - configured_project.enable_model_assisted_labeling() - use_data_row_ids = [p['dataRow']['id'] for p in model_run_predictions] - model_run.upsert_data_rows(use_data_row_ids) - - upload_task = LabelImport.create_from_objects( - client, configured_project.uid, f"label-import-{uuid.uuid4()}", - model_run_predictions) - upload_task.wait_until_done() - assert upload_task.state == AnnotationImportState.FINISHED, "Label Import did not finish" - assert len( - upload_task.errors - ) == 0, f"Label Import {upload_task.name} failed with errors {upload_task.errors}" - labels = wait_for_label_processing(configured_project) - label_ids = [label.uid for label in labels] - model_run.upsert_labels(label_ids) - yield model_run - model_run.delete() - # TODO: Delete resources when that is possible .. - - -@pytest.fixture -def model_run_with_all_project_labels(client, configured_project, - model_run_predictions, model_run, - wait_for_label_processing): - configured_project.enable_model_assisted_labeling() - use_data_row_ids = [p['dataRow']['id'] for p in model_run_predictions] - model_run.upsert_data_rows(use_data_row_ids) - - upload_task = LabelImport.create_from_objects( - client, configured_project.uid, f"label-import-{uuid.uuid4()}", - model_run_predictions) - upload_task.wait_until_done() - assert upload_task.state == AnnotationImportState.FINISHED, "Label Import did not finish" - assert len( - upload_task.errors - ) == 0, f"Label Import {upload_task.name} failed with errors {upload_task.errors}" - wait_for_label_processing(configured_project) - model_run.upsert_labels(project_id=configured_project.uid) - yield model_run - model_run.delete() - # TODO: Delete resources when that is possible .. - - -class AnnotationImportTestHelpers: - - @classmethod - def assert_file_content(cls, url: str, predictions): - response = requests.get(url) - predictions = cls._convert_to_plain_object(predictions) - assert parser.loads(response.text) == predictions - - @staticmethod - def check_running_state(req, name, url=None): - assert req.name == name - if url is not None: - assert req.input_file_url == url - assert req.error_file_url is None - assert req.status_file_url is None - assert req.state == AnnotationImportState.RUNNING - - @staticmethod - def download_and_assert_status(status_file_url): - response = requests.get(status_file_url) - assert response.status_code == 200 - for line in parser.loads(response.content): - status = line['status'] - assert status.upper() == 'SUCCESS' - - @staticmethod - def _convert_to_plain_object(obj): - """Some Python objects e.g. 
tuples can't be compared with JSON serialized data, serialize to JSON and deserialize to get plain objects""" - json_str = parser.dumps(obj) - return parser.loads(json_str) - - -@pytest.fixture -def annotation_import_test_helpers() -> Type[AnnotationImportTestHelpers]: - return AnnotationImportTestHelpers() - ----- -tests/integration/annotation_import/test_mea_prediction_import.py -import uuid -from labelbox import parser -import pytest - -from labelbox.schema.annotation_import import AnnotationImportState, MEAPredictionImport -from labelbox.data.serialization import NDJsonConverter -from labelbox.schema.export_params import ModelRunExportParams -""" -- Here we only want to check that the uploads are calling the validation -- Then with unit tests we can check the types of errors raised - -""" - - -def test_create_from_objects(model_run_with_data_rows, - object_predictions_for_annotation_import, - annotation_import_test_helpers): - name = str(uuid.uuid4()) - object_predictions = object_predictions_for_annotation_import - use_data_row_ids = [p['dataRow']['id'] for p in object_predictions] - model_run_with_data_rows.upsert_data_rows(use_data_row_ids) - - annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=object_predictions) - - assert annotation_import.model_run_id == model_run_with_data_rows.uid - annotation_import_test_helpers.check_running_state(annotation_import, name) - annotation_import_test_helpers.assert_file_content( - annotation_import.input_file_url, object_predictions) - annotation_import.wait_until_done() - - assert annotation_import.state == AnnotationImportState.FINISHED - annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) - - -def test_create_from_objects_global_key(client, model_run_with_data_rows, - polygon_inference, - annotation_import_test_helpers): - name = str(uuid.uuid4()) - dr = client.get_data_row(polygon_inference['dataRow']['id']) - del polygon_inference['dataRow']['id'] - polygon_inference['dataRow']['globalKey'] = dr.global_key - object_predictions = [polygon_inference] - - annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=object_predictions) - - assert annotation_import.model_run_id == model_run_with_data_rows.uid - annotation_import_test_helpers.check_running_state(annotation_import, name) - annotation_import_test_helpers.assert_file_content( - annotation_import.input_file_url, object_predictions) - annotation_import.wait_until_done() - - assert annotation_import.state == AnnotationImportState.FINISHED - annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) - - -def test_create_from_objects_with_confidence(predictions_with_confidence, - model_run_with_data_rows, - annotation_import_test_helpers): - name = str(uuid.uuid4()) - - object_prediction_data_rows = [ - object_prediction["dataRow"]["id"] - for object_prediction in predictions_with_confidence - ] - # MUST have all data rows in the model run - model_run_with_data_rows.upsert_data_rows( - data_row_ids=object_prediction_data_rows) - - annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=predictions_with_confidence) - - assert annotation_import.model_run_id == model_run_with_data_rows.uid - annotation_import_test_helpers.check_running_state(annotation_import, name) - annotation_import_test_helpers.assert_file_content( - annotation_import.input_file_url, predictions_with_confidence) - annotation_import.wait_until_done() 
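-
-    # once the import finishes, the status file should report SUCCESS for every prediction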
-    assert annotation_import.state == AnnotationImportState.FINISHED
-    annotation_import_test_helpers.download_and_assert_status(
-        annotation_import.status_file_url)
-
-
-def test_create_from_objects_all_project_labels(
-        model_run_with_all_project_labels,
-        object_predictions_for_annotation_import,
-        annotation_import_test_helpers):
-    name = str(uuid.uuid4())
-    object_predictions = object_predictions_for_annotation_import
-    use_data_row_ids = [p['dataRow']['id'] for p in object_predictions]
-    model_run_with_all_project_labels.upsert_data_rows(use_data_row_ids)
-
-    annotation_import = model_run_with_all_project_labels.add_predictions(
-        name=name, predictions=object_predictions)
-
-    assert annotation_import.model_run_id == model_run_with_all_project_labels.uid
-    annotation_import_test_helpers.check_running_state(annotation_import, name)
-    annotation_import_test_helpers.assert_file_content(
-        annotation_import.input_file_url, object_predictions)
-    annotation_import.wait_until_done()
-
-    assert annotation_import.state == AnnotationImportState.FINISHED
-    annotation_import_test_helpers.download_and_assert_status(
-        annotation_import.status_file_url)
-
-
-def test_model_run_project_labels(model_run_with_all_project_labels,
-                                  model_run_predictions):
-    model_run = model_run_with_all_project_labels
-    # TODO: Move to export_v2
-    model_run_exported_labels = model_run.export_labels(download=True)
-    labels_indexed_by_schema_id = {}
-
-    for label in model_run_exported_labels:
-        # assuming the exported array of label 'objects' has only one label per data row, as is usually the case when there are no label revisions
-        schema_id = label['Label']['objects'][0]['schemaId']
-        labels_indexed_by_schema_id[schema_id] = label
-
-    assert (len(
-        labels_indexed_by_schema_id.keys())) == len(model_run_predictions)
-
-    # make sure the labels in this model run are exactly the labels uploaded to the project
-    # by comparing some 'immutable' attributes
-    for expected_label in model_run_predictions:
-        schema_id = expected_label['schemaId']
-        actual_label = labels_indexed_by_schema_id[schema_id]
-        assert actual_label['Label']['objects'][0]['title'] == expected_label[
-            'name']
-        assert actual_label['DataRow ID'] == expected_label['dataRow']['id']
-
-
-def test_create_from_label_objects(model_run_with_data_rows,
-                                   object_predictions_for_annotation_import,
-                                   annotation_import_test_helpers):
-    name = str(uuid.uuid4())
-    use_data_row_ids = [
-        p['dataRow']['id'] for p in object_predictions_for_annotation_import
-    ]
-    model_run_with_data_rows.upsert_data_rows(use_data_row_ids)
-
-    predictions = list(
-        NDJsonConverter.deserialize(object_predictions_for_annotation_import))
-
-    annotation_import = model_run_with_data_rows.add_predictions(
-        name=name, predictions=predictions)
-
-    assert annotation_import.model_run_id == model_run_with_data_rows.uid
-    annotation_import_test_helpers.check_running_state(annotation_import, name)
-    normalized_predictions = NDJsonConverter.serialize(predictions)
-    annotation_import_test_helpers.assert_file_content(
-        annotation_import.input_file_url, normalized_predictions)
-    annotation_import.wait_until_done()
-
-    assert annotation_import.state == AnnotationImportState.FINISHED
-    annotation_import_test_helpers.download_and_assert_status(
-        annotation_import.status_file_url)
-
-
-def test_create_from_local_file(tmp_path, model_run_with_data_rows,
-                                object_predictions_for_annotation_import,
-                                annotation_import_test_helpers):
-    use_data_row_ids = [
-        p['dataRow']['id'] for p in
object_predictions_for_annotation_import - ] - model_run_with_data_rows.upsert_data_rows(use_data_row_ids) - - name = str(uuid.uuid4()) - file_name = f"{name}.ndjson" - file_path = tmp_path / file_name - with file_path.open("w") as f: - parser.dump(object_predictions_for_annotation_import, f) - - annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=str(file_path)) - - assert annotation_import.model_run_id == model_run_with_data_rows.uid - annotation_import_test_helpers.check_running_state(annotation_import, name) - annotation_import_test_helpers.assert_file_content( - annotation_import.input_file_url, - object_predictions_for_annotation_import) - annotation_import.wait_until_done() - - assert annotation_import.state == AnnotationImportState.FINISHED - annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) - - -def test_predictions_with_custom_metrics( - model_run, object_predictions_for_annotation_import, - annotation_import_test_helpers): - name = str(uuid.uuid4()) - object_predictions = object_predictions_for_annotation_import - use_data_row_ids = [p['dataRow']['id'] for p in object_predictions] - model_run.upsert_data_rows(use_data_row_ids) - - annotation_import = model_run.add_predictions( - name=name, predictions=object_predictions) - - assert annotation_import.model_run_id == model_run.uid - annotation_import.wait_until_done() - assert annotation_import.state == AnnotationImportState.FINISHED - - task = model_run.export_v2(params=ModelRunExportParams(predictions=True)) - task.wait_till_done() - - assert annotation_import.state == AnnotationImportState.FINISHED - annotation_import_test_helpers.download_and_assert_status( - annotation_import.status_file_url) - - -def test_get(client, model_run_with_data_rows, annotation_import_test_helpers): - name = str(uuid.uuid4()) - url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - model_run_with_data_rows.add_predictions(name=name, predictions=url) - - annotation_import = MEAPredictionImport.from_name( - client, model_run_id=model_run_with_data_rows.uid, name=name) - - assert annotation_import.model_run_id == model_run_with_data_rows.uid - annotation_import_test_helpers.check_running_state(annotation_import, name, - url) - annotation_import.wait_until_done() - - -@pytest.mark.slow -def test_wait_till_done(model_run_predictions, model_run_with_data_rows): - name = str(uuid.uuid4()) - annotation_import = model_run_with_data_rows.add_predictions( - name=name, predictions=model_run_predictions) - - assert len(annotation_import.inputs) == len(model_run_predictions) - annotation_import.wait_until_done() - assert annotation_import.state == AnnotationImportState.FINISHED - # Check that the status files are being returned as expected - assert len(annotation_import.errors) == 0 - assert len(annotation_import.inputs) == len(model_run_predictions) - input_uuids = [ - input_annot['uuid'] for input_annot in annotation_import.inputs - ] - inference_uuids = [pred['uuid'] for pred in model_run_predictions] - assert set(input_uuids) == set(inference_uuids) - assert len(annotation_import.statuses) == len(model_run_predictions) - for status in annotation_import.statuses: - assert status['status'] == 'SUCCESS' - status_uuids = [ - input_annot['uuid'] for input_annot in annotation_import.statuses - ] - assert set(input_uuids) == set(status_uuids) - ----- -tests/integration/annotation_import/test_send_to_annotate_mea.py -import pytest - -from labelbox 
import UniqueIds, OntologyBuilder -from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy - - -def test_send_to_annotate_from_model(client, configured_project, - model_run_predictions, - model_run_with_data_rows, project): - model_run = model_run_with_data_rows - data_row_ids = [p['dataRow']['id'] for p in model_run_predictions] - assert len(data_row_ids) > 0 - - destination_project = project - model = client.get_model(model_run.model_id) - ontology = client.get_ontology(model.ontology_id) - destination_project.setup_editor(ontology) - - queues = destination_project.task_queues() - initial_review_task = next( - q for q in queues if q.name == "Initial review task") - - # build an ontology mapping using the top level tools and classifications - source_ontology_builder = OntologyBuilder.from_project(configured_project) - feature_schema_ids = list( - tool.feature_schema_id for tool in source_ontology_builder.tools) - # create a dictionary of feature schema id to itself - ontology_mapping = dict(zip(feature_schema_ids, feature_schema_ids)) - - classification_feature_schema_ids = list( - classification.feature_schema_id - for classification in source_ontology_builder.classifications) - # create a dictionary of feature schema id to itself - classification_ontology_mapping = dict( - zip(classification_feature_schema_ids, - classification_feature_schema_ids)) - - # combine the two ontology mappings - ontology_mapping.update(classification_ontology_mapping) - - task = model_run.send_to_annotate_from_model( - destination_project_id=destination_project.uid, - batch_name="batch", - data_rows=UniqueIds(data_row_ids), - task_queue_id=initial_review_task.uid, - params={ - "predictions_ontology_mapping": - ontology_mapping, - "override_existing_annotations_rule": - ConflictResolutionStrategy.OverrideWithPredictions - }) - - task.wait_till_done() - - # Check that the data row was sent to the new project - destination_batches = list(destination_project.batches()) - assert len(destination_batches) == 1 - - destination_data_rows = list(destination_batches[0].export_data_rows()) - assert len(destination_data_rows) == len(data_row_ids) - assert all([dr.uid in data_row_ids for dr in destination_data_rows]) - - # Since data rows were added to a review queue, predictions should be imported into the project as labels - destination_project_labels = (list(destination_project.labels())) - assert len(destination_project_labels) == len(data_row_ids) - ----- -tests/integration/annotation_import/test_ndjson_validation.py -from labelbox.schema.media_type import MediaType -import pytest - -from labelbox import parser -from pytest_cases import parametrize, fixture_ref - -from labelbox.exceptions import MALValidationError -from labelbox.schema.bulk_import_request import (NDChecklist, NDClassification, - NDMask, NDPolygon, NDPolyline, - NDRadio, NDRectangle, NDText, - NDTextEntity, NDTool, - _validate_ndjson) -from labelbox.schema.labeling_frontend import LabelingFrontend -from labelbox.schema.queue_mode import QueueMode - - -@pytest.fixture(scope="module", autouse=True) -def hardcoded_datarow_id(): - data_row_id = 'ck8q9q9qj00003g5z3q1q9q9q' - - def get_data_row_id(indx=0): - return data_row_id - - yield get_data_row_id - - -@pytest.fixture(scope="module", autouse=True) -def configured_project_with_ontology(client, ontology, rand_gen): - project = client.create_project( - name=rand_gen(str), - queue_mode=QueueMode.Batch, - media_type=MediaType.Image, - ) - editor = list( - 
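-        # grab the default "editor" labeling frontend so the project can be set up with the ontology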
client.get_labeling_frontends(
-            where=LabelingFrontend.name == "editor"))[0]
-    project.setup(editor, ontology)
-
-    yield project
-
-    project.delete()
-
-
-def test_classification_construction(checklist_inference, text_inference):
-    checklist = NDClassification.build(checklist_inference)
-    assert isinstance(checklist, NDChecklist)
-    text = NDClassification.build(text_inference)
-    assert isinstance(text, NDText)
-
-
-def test_subclassification_construction(rectangle_inference):
-    tool = NDTool.build(rectangle_inference)
-    assert len(tool.classifications) == 1, "Subclass was not constructed"
-    assert isinstance(tool.classifications[0], NDRadio)
-
-
-@parametrize("inference, expected_type",
-             [(fixture_ref('polygon_inference'), NDPolygon),
-              (fixture_ref('rectangle_inference'), NDRectangle),
-              (fixture_ref('line_inference'), NDPolyline),
-              (fixture_ref('entity_inference'), NDTextEntity),
-              (fixture_ref('segmentation_inference'), NDMask),
-              (fixture_ref('segmentation_inference_rle'), NDMask),
-              (fixture_ref('segmentation_inference_png'), NDMask)])
-def test_tool_construction(inference, expected_type):
-    assert isinstance(NDTool.build(inference), expected_type)
-
-
-def test_incorrect_feature_schema(rectangle_inference, polygon_inference,
-                                  configured_project_with_ontology):
-    #Valid but incorrect feature schema
-    #The error message probably mentions the config rather than anything useful. We might want to fix this.
-    pred = rectangle_inference.copy()
-    pred['schemaId'] = polygon_inference['schemaId']
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-
-
-def test_no_tool(text_inference, configured_project_with_ontology):
-    pred = text_inference.copy()
-    #Missing key
-    del pred['answer']
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-
-
-def test_invalid_text(text_inference, configured_project_with_ontology):
-    #The answer must be present and must be a string
-    pred = text_inference.copy()
-    #Extra and wrong key
-    del pred['answer']
-    pred['answers'] = []
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-    del pred['answers']
-
-    #Invalid type
-    pred['answer'] = []
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-
-    #Invalid type
-    pred['answer'] = None
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-
-
-def test_invalid_checklist_item(checklist_inference,
-                                configured_project_with_ontology):
-    pred = checklist_inference.copy()
-    pred['answers'] = [pred['answers'][0], pred['answers'][0]]
-    #Duplicate schema ids
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-
-    pred['answers'] = [{"name": "asdfg"}]
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-
-    pred['answers'] = [{"schemaId": "1232132132"}]
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-
-    pred['answers'] = [{}]
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-
-    pred['answers'] = []
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-
-    del pred['answers']
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-
-
-def test_invalid_polygon(polygon_inference,
configured_project_with_ontology):
-    #Only two points
-    pred = polygon_inference.copy()
-    pred['polygon'] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}]
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-
-
-def test_incorrect_entity(entity_inference, configured_project_with_ontology):
-    entity = entity_inference.copy()
-    #Location cannot be a list
-    entity["location"] = [0, 10]
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([entity], configured_project_with_ontology)
-
-    entity["location"] = {"start": -1, "end": 5}
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([entity], configured_project_with_ontology)
-
-    entity["location"] = {"start": 15, "end": 5}
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([entity], configured_project_with_ontology)
-
-
-def test_incorrect_mask(segmentation_inference,
-                        configured_project_with_ontology):
-    seg = segmentation_inference.copy()
-    seg['mask']['colorRGB'] = [-1, 0, 10]
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([seg], configured_project_with_ontology)
-
-    seg['mask']['colorRGB'] = [0, 0]
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([seg], configured_project_with_ontology)
-
-    seg['mask'] = {'counts': [0], 'size': [0, 1]}
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([seg], configured_project_with_ontology)
-
-    seg['mask'] = {'counts': [-1], 'size': [1, 1]}
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([seg], configured_project_with_ontology)
-
-
-def test_all_validate_json(configured_project_with_ontology, predictions):
-    #Predictions contains one of each type of prediction.
-    #These should be properly formatted and pass.
-    _validate_ndjson(predictions, configured_project_with_ontology)
-
-
-def test_incorrect_line(line_inference, configured_project_with_ontology):
-    line = line_inference.copy()
-    line["line"] = [line["line"][0]] #Just one point
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([line], configured_project_with_ontology)
-
-
-def test_incorrect_rectangle(rectangle_inference,
-                             configured_project_with_ontology):
-    del rectangle_inference['bbox']['top']
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([rectangle_inference],
-                         configured_project_with_ontology)
-
-
-def test_duplicate_tools(rectangle_inference, configured_project_with_ontology):
-    #Trying to upload a polygon and rectangle at the same time
-    pred = rectangle_inference.copy()
-    pred['polygon'] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}]
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-
-
-def test_invalid_feature_schema(configured_project_with_ontology,
-                                rectangle_inference):
-    #A schema id that doesn't exist in the ontology should fail validation
-    pred = rectangle_inference.copy()
-    pred['schemaId'] = "blahblah"
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-
-
-def test_name_only_feature_schema(configured_project_with_ontology,
-                                  rectangle_inference):
-    #Uploading with only a name (no schema id) is allowed
-    pred = rectangle_inference.copy()
-    del pred['schemaId']
-    _validate_ndjson([pred], configured_project_with_ontology)
-
-
-def test_schema_id_only_feature_schema(configured_project_with_ontology,
-                                       rectangle_inference):
-    #Uploading with only a schema id (no name) is allowed
-    pred = rectangle_inference.copy()
-    del pred['name']
-    _validate_ndjson([pred],
configured_project_with_ontology)
-
-
-def test_missing_feature_schema(configured_project_with_ontology,
-                                rectangle_inference):
-    #Validation fails when both the schema id and the name are missing
-    pred = rectangle_inference.copy()
-    del pred['schemaId']
-    del pred['name']
-    with pytest.raises(MALValidationError):
-        _validate_ndjson([pred], configured_project_with_ontology)
-
-
-def test_validate_ndjson(tmp_path, configured_project_with_ontology):
-    file_name = "broken.ndjson"
-    file_path = tmp_path / file_name
-    with file_path.open("w") as f:
-        f.write("test")
-
-    with pytest.raises(ValueError):
-        configured_project_with_ontology.upload_annotations(
-            name="name", annotations=str(file_path), validate=True)
-
-
-def test_validate_ndjson_uuid(tmp_path, configured_project_with_ontology,
-                              predictions):
-    file_name = "repeat_uuid.ndjson"
-    file_path = tmp_path / file_name
-    repeat_uuid = predictions.copy()
-    repeat_uuid[0]['uuid'] = 'test_uuid'
-    repeat_uuid[1]['uuid'] = 'test_uuid'
-
-    with file_path.open("w") as f:
-        parser.dump(repeat_uuid, f)
-
-    with pytest.raises(MALValidationError):
-        configured_project_with_ontology.upload_annotations(
-            name="name", validate=True, annotations=str(file_path))
-
-    with pytest.raises(MALValidationError):
-        configured_project_with_ontology.upload_annotations(
-            name="name", validate=True, annotations=repeat_uuid)
-
-
-def test_video_upload(video_checklist_inference,
-                      configured_project_with_ontology):
-    pred = video_checklist_inference.copy()
-    _validate_ndjson([pred], configured_project_with_ontology)
-
-----
-tests/integration/annotation_import/test_upsert_prediction_import.py
-import uuid
-from labelbox import parser
-import pytest
-"""
-- Here we only want to check that the uploads are calling the validation
-- Then with unit tests we can check the types of errors raised
-
-"""
-
-
-@pytest.mark.skip()
-def test_create_from_url(client, tmp_path, object_predictions,
-                         model_run_with_data_rows,
-                         configured_project_with_one_data_row,
-                         annotation_import_test_helpers):
-    name = str(uuid.uuid4())
-    file_name = f"{name}.json"
-    file_path = tmp_path / file_name
-
-    model_run_data_rows = [
-        mrdr.data_row().uid
-        for mrdr in model_run_with_data_rows.model_run_data_rows()
-    ]
-    predictions = [
-        p for p in object_predictions
-        if p['dataRow']['id'] in model_run_data_rows
-    ]
-    with file_path.open("w") as f:
-        parser.dump(predictions, f)
-
-    # The uploaded predictions need to reference data row ids
-
-    with open(file_path, "r") as f:
-        url = client.upload_data(content=f.read(),
-                                 filename=file_name,
-                                 sign=True,
-                                 content_type="application/json")
-
-    annotation_import, batch, mal_prediction_import = model_run_with_data_rows.upsert_predictions_and_send_to_project(
-        name=name,
-        predictions=url,
-        project_id=configured_project_with_one_data_row.uid,
-        priority=5)
-
-    assert annotation_import.model_run_id == model_run_with_data_rows.uid
-    annotation_import.wait_until_done()
-    assert not annotation_import.errors
-    assert annotation_import.statuses
-
-    assert batch
-    assert batch.project().uid == configured_project_with_one_data_row.uid
-
-    assert mal_prediction_import
-    mal_prediction_import.wait_until_done()
-
-    assert not mal_prediction_import.errors
-    assert mal_prediction_import.statuses
-
-
-@pytest.mark.skip()
-def test_create_from_objects(model_run_with_data_rows,
-                             configured_project_with_one_data_row,
-                             object_predictions,
-                             annotation_import_test_helpers):
-    name = str(uuid.uuid4())
-    model_run_data_rows = [
-        mrdr.data_row().uid
-        for mrdr in
model_run_with_data_rows.model_run_data_rows() - ] - predictions = [ - p for p in object_predictions - if p['dataRow']['id'] in model_run_data_rows - ] - annotation_import, batch, mal_prediction_import = model_run_with_data_rows.upsert_predictions_and_send_to_project( - name=name, - predictions=predictions, - project_id=configured_project_with_one_data_row.uid, - priority=5) - - assert annotation_import.model_run_id == model_run_with_data_rows.uid - annotation_import.wait_until_done() - assert not annotation_import.errors - assert annotation_import.statuses - - assert batch - assert batch.project().uid == configured_project_with_one_data_row.uid - - assert mal_prediction_import - mal_prediction_import.wait_until_done() - - assert not mal_prediction_import.errors - assert mal_prediction_import.statuses - - -@pytest.mark.skip() -def test_create_from_local_file(tmp_path, model_run_with_data_rows, - configured_project_with_one_data_row, - object_predictions, - annotation_import_test_helpers): - - name = str(uuid.uuid4()) - file_name = f"{name}.ndjson" - file_path = tmp_path / file_name - - model_run_data_rows = [ - mrdr.data_row().uid - for mrdr in model_run_with_data_rows.model_run_data_rows() - ] - predictions = [ - p for p in object_predictions - if p['dataRow']['id'] in model_run_data_rows - ] - - with file_path.open("w") as f: - parser.dump(predictions, f) - - annotation_import, batch, mal_prediction_import = model_run_with_data_rows.upsert_predictions_and_send_to_project( - name=name, - predictions=str(file_path), - project_id=configured_project_with_one_data_row.uid, - priority=5) - - assert annotation_import.model_run_id == model_run_with_data_rows.uid - annotation_import.wait_until_done() - assert not annotation_import.errors - assert annotation_import.statuses - - assert batch - assert batch.project().uid == configured_project_with_one_data_row.uid - - assert mal_prediction_import - mal_prediction_import.wait_until_done() - - assert not mal_prediction_import.errors - assert mal_prediction_import.statuses - ----- -tests/integration/annotation_import/test_conversation_import.py -import uuid -from labelbox.data.annotation_types.annotation import ObjectAnnotation -from labelbox.data.annotation_types.label import Label -from labelbox.data.annotation_types.data.text import TextData -from labelbox.data.annotation_types.ner import ConversationEntity - -from labelbox.schema.annotation_import import MALPredictionImport - - -def test_conversation_entity(client, configured_project_with_one_data_row, - dataset_conversation_entity, rand_gen): - - conversation_entity_annotation = ConversationEntity(start=0, - end=8, - message_id="4") - - entities_annotation = ObjectAnnotation(name="named-entity", - value=conversation_entity_annotation) - - labels = [] - _, data_row_uids = dataset_conversation_entity - - configured_project_with_one_data_row.create_batch( - rand_gen(str), - data_row_uids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - - for data_row_uid in data_row_uids: - labels.append( - Label(data=TextData(uid=data_row_uid), - annotations=[ - entities_annotation, - ])) - - import_annotations = MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=f"import {str(uuid.uuid4())}", - predictions=labels) - - import_annotations.wait_until_done() - - assert import_annotations.errors == [] - ----- -tests/integration/annotation_import/test_data_types.py -import datetime -import itertools -import pytest 
-import uuid - -import labelbox as lb -from labelbox.data.annotation_types.data.video import VideoData -from labelbox.schema.data_row import DataRow -from labelbox.schema.media_type import MediaType -import labelbox.types as lb_types -from labelbox.data.annotation_types.data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, TextData, LlmPromptCreationData, LlmPromptResponseCreationData, LlmResponseCreationData -from labelbox.data.serialization import NDJsonConverter -from labelbox.schema.annotation_import import AnnotationImportState -from utils import remove_keys_recursive, rename_cuid_key_recursive - -DATA_ROW_PROCESSING_WAIT_TIMEOUT_SECONDS = 40 -DATA_ROW_PROCESSING_WAIT_SLEEP_INTERNAL_SECONDS = 7 - -radio_annotation = lb_types.ClassificationAnnotation( - name="radio", - value=lb_types.Radio(answer=lb_types.ClassificationAnswer( - name="second_radio_answer"))) -checklist_annotation = lb_types.ClassificationAnnotation( - name="checklist", - value=lb_types.Checklist(answer=[ - lb_types.ClassificationAnswer(name="option1"), - lb_types.ClassificationAnswer(name="option2") - ])) -text_annotation = lb_types.ClassificationAnnotation( - name="text", value=lb_types.Text(answer="sample text")) - -video_mask_annotation = lb_types.VideoMaskAnnotation(frames=[ - lb_types.MaskFrame( - index=10, - instance_uri= - "https://storage.googleapis.com/labelbox-datasets/video-sample-data/mask_example.png" - ) -], - instances=[ - lb_types.MaskInstance( - color_rgb=(255, - 255, - 255), - name= - "segmentation_mask" - ) - ]) - -test_params = [[ - 'html', lb_types.HTMLData, - [radio_annotation, checklist_annotation, text_annotation] -], - [ - 'audio', lb_types.AudioData, - [radio_annotation, checklist_annotation, text_annotation] - ], ['video', lb_types.VideoData, [video_mask_annotation]]] - - -def get_annotation_comparison_dicts_from_labels(labels): - labels_ndjson = list(NDJsonConverter.serialize(labels)) - for annotation in labels_ndjson: - annotation.pop('uuid', None) - annotation.pop('dataRow') - - if 'masks' in annotation: - for frame in annotation['masks']['frames']: - frame.pop('instanceURI') - frame.pop('imBytes') - for instance in annotation['masks']['instances']: - instance.pop('colorRGB') - return labels_ndjson - - -def get_annotation_comparison_dicts_from_export(export_result, data_row_id, - project_id): - exported_data_row = [ - dr for dr in export_result if dr['data_row']['id'] == data_row_id - ][0] - exported_label = exported_data_row['projects'][project_id]['labels'][0] - exported_annotations = exported_label['annotations'] - converted_annotations = [] - if exported_label['label_kind'] == 'Video': - frames = [] - instances = [] - for frame_id, frame in exported_annotations['frames'].items(): - frames.append({'index': int(frame_id)}) - for object in frame['objects'].values(): - instances.append({'name': object['name']}) - converted_annotations.append( - {'masks': { - 'frames': frames, - 'instances': instances, - }}) - else: - exported_annotations = list( - itertools.chain(*exported_annotations.values())) - for annotation in exported_annotations: - if annotation['name'] == 'radio': - converted_annotations.append({ - 'name': annotation['name'], - 'answer': { - 'name': annotation['radio_answer']['name'] - } - }) - elif annotation['name'] == 'checklist': - converted_annotations.append({ - 'name': - annotation['name'], - 'answer': [{ - 'name': answer['name'] - } for answer in annotation['checklist_answers']] - }) - elif annotation['name'] == 'text': - 
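-                # text answers export under 'text_answer' -> 'content'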
converted_annotations.append({ - 'name': annotation['name'], - 'answer': annotation['text_answer']['content'] - }) - return converted_annotations - - -def create_data_row_for_project(project, dataset, data_row_ndjson, batch_name): - data_row = dataset.create_data_row(data_row_ndjson) - - project.create_batch( - batch_name, - [data_row.uid], # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - project.data_row_ids.append(data_row.uid) - - return data_row - - -# TODO: Add VideoData. Currently label import job finishes without errors but project.export_labels() returns empty list. -@pytest.mark.parametrize('data_type_class', [ - AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, - TextData, LlmPromptCreationData, LlmPromptResponseCreationData, - LlmResponseCreationData -]) -def test_import_data_types( - client, - configured_project, - initial_dataset, - rand_gen, - data_row_json_by_data_type, - annotations_by_data_type, - data_type_class, -): - - project = configured_project - project_id = project.uid - dataset = initial_dataset - - set_project_media_type_from_data_type(project, data_type_class) - - data_type_string = data_type_class.__name__[:-4].lower() - data_row_ndjson = data_row_json_by_data_type[data_type_string] - data_row = create_data_row_for_project(project, dataset, data_row_ndjson, - rand_gen(str)) - - annotations_ndjson = annotations_by_data_type[data_type_string] - annotations_list = [ - label.annotations - for label in NDJsonConverter.deserialize(annotations_ndjson) - ] - labels = [ - lb_types.Label(data=data_type_class(uid=data_row.uid), - annotations=annotations) - for annotations in annotations_list - ] - - label_import = lb.LabelImport.create_from_objects( - client, project_id, f'test-import-{data_type_string}', labels) - label_import.wait_until_done() - - assert label_import.state == AnnotationImportState.FINISHED - assert len(label_import.errors) == 0 - exported_labels = project.export_labels(download=True) - objects = exported_labels[0]['Label']['objects'] - classifications = exported_labels[0]['Label']['classifications'] - assert len(objects) + len(classifications) == len(labels) - data_row.delete() - - -def test_import_data_types_by_global_key( - client, - configured_project, - initial_dataset, - rand_gen, - data_row_json_by_data_type, - annotations_by_data_type, -): - - project = configured_project - project_id = project.uid - dataset = initial_dataset - data_type_class = ImageData - set_project_media_type_from_data_type(project, data_type_class) - - data_row_ndjson = data_row_json_by_data_type['image'] - data_row_ndjson['global_key'] = str(uuid.uuid4()) - data_row = create_data_row_for_project(project, dataset, data_row_ndjson, - rand_gen(str)) - - annotations_ndjson = annotations_by_data_type['image'] - annotations_list = [ - label.annotations - for label in NDJsonConverter.deserialize(annotations_ndjson) - ] - labels = [ - lb_types.Label(data=data_type_class(global_key=data_row.global_key), - annotations=annotations) - for annotations in annotations_list - ] - - label_import = lb.LabelImport.create_from_objects(client, project_id, - f'test-import-image', - labels) - label_import.wait_until_done() - - assert label_import.state == AnnotationImportState.FINISHED - assert len(label_import.errors) == 0 - exported_labels = project.export_labels(download=True) - objects = exported_labels[0]['Label']['objects'] - classifications = exported_labels[0]['Label']['classifications'] - assert len(objects) + len(classifications) 
== len(labels) - data_row.delete() - - -def validate_iso_format(date_string: str): - parsed_t = datetime.datetime.fromisoformat( - date_string) #this will blow up if the string is not in iso format - assert parsed_t.hour is not None - assert parsed_t.minute is not None - assert parsed_t.second is not None - - -def to_pascal_case(name: str) -> str: - return "".join([word.capitalize() for word in name.split("_")]) - - -def set_project_media_type_from_data_type(project, data_type_class): - data_type_string = data_type_class.__name__[:-4].lower() - media_type = to_pascal_case(data_type_string) - if media_type == 'Conversation': - media_type = 'Conversational' - elif media_type == 'Llmpromptcreation': - media_type = 'LLMPromptCreation' - elif media_type == 'Llmpromptresponsecreation': - media_type = 'LLMPromptResponseCreation' - elif media_type == 'Llmresponsecreation': - media_type = 'Text' - project.update(media_type=MediaType[media_type]) - - -@pytest.mark.parametrize('data_type_class', [ - AudioData, HTMLData, ImageData, TextData, VideoData, ConversationData, - DocumentData, DicomData, LlmPromptCreationData, - LlmPromptResponseCreationData, LlmResponseCreationData -]) -def test_import_data_types_v2(client, configured_project, initial_dataset, - data_row_json_by_data_type, - annotations_by_data_type_v2, data_type_class, - exports_v2_by_data_type, export_v2_test_helpers, - rand_gen): - - project = configured_project - dataset = initial_dataset - project_id = project.uid - - set_project_media_type_from_data_type(project, data_type_class) - - data_type_string = data_type_class.__name__[:-4].lower() - data_row_ndjson = data_row_json_by_data_type[data_type_string] - data_row = create_data_row_for_project(project, dataset, data_row_ndjson, - rand_gen(str)) - annotations_ndjson = annotations_by_data_type_v2[data_type_string] - annotations_list = [ - label.annotations - for label in NDJsonConverter.deserialize(annotations_ndjson) - ] - labels = [ - lb_types.Label(data=data_type_class(uid=data_row.uid), - annotations=annotations) - for annotations in annotations_list - ] - - label_import = lb.LabelImport.create_from_objects( - client, project_id, f'test-import-{data_type_string}', labels) - label_import.wait_until_done() - - assert label_import.state == AnnotationImportState.FINISHED - assert len(label_import.errors) == 0 - - #TODO need to migrate project to the new BATCH mode and change this code - # to be similar to tests/integration/test_task_queue.py - - result = export_v2_test_helpers.run_project_export_v2_task(project) - exported_data = result[0] - - # timestamp fields are in iso format - validate_iso_format(exported_data['data_row']['details']['created_at']) - validate_iso_format(exported_data['data_row']['details']['updated_at']) - validate_iso_format(exported_data['projects'][project_id]['labels'][0] - ['label_details']['created_at']) - validate_iso_format(exported_data['projects'][project_id]['labels'][0] - ['label_details']['updated_at']) - - assert (exported_data['data_row']['id'] == data_row.uid) - exported_project = exported_data['projects'][project_id] - exported_project_labels = exported_project['labels'][0] - exported_annotations = exported_project_labels['annotations'] - - remove_keys_recursive(exported_annotations, - ['feature_id', 'feature_schema_id']) - rename_cuid_key_recursive(exported_annotations) - assert exported_annotations == exports_v2_by_data_type[data_type_string] - - data_row = client.get_data_row(data_row.uid) - data_row.delete() - - 
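-# The parametrized tests below iterate over test_params (defined above); each
-# entry is a (data_type key, data class, annotations) triple, e.g.
-# ('html', lb_types.HTMLData, [radio_annotation, checklist_annotation, text_annotation]).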
-@pytest.mark.parametrize('data_type, data_class, annotations', test_params) -def test_import_label_annotations(client, configured_project_with_one_data_row, - initial_dataset, data_row_json_by_data_type, - data_type, data_class, annotations, rand_gen): - - project = configured_project_with_one_data_row - dataset = initial_dataset - set_project_media_type_from_data_type(project, data_class) - - data_row_json = data_row_json_by_data_type[data_type] - data_row = create_data_row_for_project(project, dataset, data_row_json, - rand_gen(str)) - - labels = [ - lb_types.Label(data=data_class(uid=data_row.uid), - annotations=annotations) - ] - - label_import = lb.LabelImport.create_from_objects(client, project.uid, - f'test-import-html', - labels) - label_import.wait_until_done() - - assert label_import.state == lb.AnnotationImportState.FINISHED - assert len(label_import.errors) == 0 - export_params = { - "attachments": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False - } - export_task = project.export_v2(params=export_params) - export_task.wait_till_done() - assert export_task.errors is None - expected_annotations = get_annotation_comparison_dicts_from_labels(labels) - actual_annotations = get_annotation_comparison_dicts_from_export( - export_task.result, data_row.uid, - configured_project_with_one_data_row.uid) - assert actual_annotations == expected_annotations - data_row.delete() - - -@pytest.mark.parametrize('data_type, data_class, annotations', test_params) -@pytest.fixture -def one_datarow(client, rand_gen, data_row_json_by_data_type, data_type): - dataset = client.create_dataset(name=rand_gen(str)) - data_row_json = data_row_json_by_data_type[data_type] - data_row = dataset.create_data_row(data_row_json) - - yield data_row - - dataset.delete() - - -@pytest.fixture -def one_datarow_global_key(client, rand_gen, data_row_json_by_data_type): - dataset = client.create_dataset(name=rand_gen(str)) - data_row_json = data_row_json_by_data_type['video'] - data_row = dataset.create_data_row(data_row_json) - - yield data_row - - dataset.delete() - - -@pytest.mark.parametrize('data_type, data_class, annotations', test_params) -def test_import_mal_annotations(client, configured_project_with_one_data_row, - data_type, data_class, annotations, rand_gen, - one_datarow): - data_row = one_datarow - set_project_media_type_from_data_type(configured_project_with_one_data_row, - data_class) - - configured_project_with_one_data_row.create_batch( - rand_gen(str), - [data_row.uid], - ) - - labels = [ - lb_types.Label(data=data_class(uid=data_row.uid), - annotations=annotations) - ] - - import_annotations = lb.MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=f"import {str(uuid.uuid4())}", - predictions=labels) - import_annotations.wait_until_done() - - assert import_annotations.errors == [] - # MAL Labels cannot be exported and compared to input labels - - -def test_import_mal_annotations_global_key(client, - configured_project_with_one_data_row, - rand_gen, one_datarow_global_key): - data_class = lb_types.VideoData - data_row = one_datarow_global_key - annotations = [video_mask_annotation] - set_project_media_type_from_data_type(configured_project_with_one_data_row, - data_class) - - configured_project_with_one_data_row.create_batch( - rand_gen(str), - [data_row.uid], - ) - - labels = [ - lb_types.Label(data=data_class(global_key=data_row.global_key), - 
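-                       # the data row is referenced by global key instead of uid here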
annotations=annotations) - ] - - import_annotations = lb.MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=f"import {str(uuid.uuid4())}", - predictions=labels) - import_annotations.wait_until_done() - - assert import_annotations.errors == [] - # MAL Labels cannot be exported and compared to input labels - ----- -tests/integration/annotation_import/test_model_run.py -import time -import os -import pytest - -from collections import Counter - -from labelbox import DataSplit, ModelRun - - -def test_model_run(client, configured_project_with_label, data_row, rand_gen): - project, _, _, label = configured_project_with_label - label_id = label.uid - ontology = project.ontology() - data = {"name": rand_gen(str), "ontology_id": ontology.uid} - model = client.create_model(data["name"], data["ontology_id"]) - - name = rand_gen(str) - config = {"batch_size": 100, "reruns": None} - model_run = model.create_model_run(name, config) - assert model_run.name == name - assert model_run.training_metadata["batchSize"] == config["batch_size"] - assert model_run.training_metadata["reruns"] == config["reruns"] - assert model_run.model_id == model.uid - assert model_run.created_by_id == client.get_user().uid - - model_run.upsert_labels([label_id]) - time.sleep(3) - - model_run_data_row = next(model_run.model_run_data_rows()) - assert model_run_data_row.label_id == label_id - assert model_run_data_row.model_run_id == model_run.uid - assert model_run_data_row.data_row().uid == data_row.uid - - fetch_model_run = client.get_model_run(model_run.uid) - assert fetch_model_run == model_run - - -def test_model_run_no_config(rand_gen, model): - name = rand_gen(str) - model_run = model.create_model_run(name) - assert model_run.name == name - - -def test_model_run_delete(client, model_run): - models_before = list(client.get_models()) - model_before = models_before[0] - before = list(model_before.model_runs()) - - model_run = before[0] - model_run.delete() - - models_after = list(client.get_models()) - model_after = models_after[0] - after = list(model_after.model_runs()) - after_uids = {mr.uid for mr in after} - - assert model_run.uid not in after_uids - - -def test_model_run_update_config(model_run_with_training_metadata): - new_config = {"batch_size": 2000} - res = model_run_with_training_metadata.update_config(new_config) - assert res["trainingMetadata"]["batch_size"] == new_config["batch_size"] - - -def test_model_run_reset_config(model_run_with_training_metadata): - res = model_run_with_training_metadata.reset_config() - assert res["trainingMetadata"] is None - - -def test_model_run_get_config(model_run_with_training_metadata): - new_config = {"batch_size": 2000} - model_run_with_training_metadata.update_config(new_config) - res = model_run_with_training_metadata.get_config() - assert res["batch_size"] == new_config["batch_size"] - - -def test_model_run_data_rows_delete(model_run_with_data_rows): - model_run = model_run_with_data_rows - - before = list(model_run.model_run_data_rows()) - annotation_data_row = before[0] - - data_row_id = annotation_data_row.data_row().uid - model_run.delete_model_run_data_rows(data_row_ids=[data_row_id]) - after = list(model_run.model_run_data_rows()) - assert len(before) == len(after) + 1 - - -def test_model_run_upsert_data_rows(dataset, model_run, - configured_project_with_one_data_row): - n_model_run_data_rows = len(list(model_run.model_run_data_rows())) - assert n_model_run_data_rows == 0 - data_row = 
dataset.create_data_row(row_data="test row data") - configured_project_with_one_data_row._wait_until_data_rows_are_processed( - data_row_ids=[data_row.uid]) - model_run.upsert_data_rows([data_row.uid]) - n_model_run_data_rows = len(list(model_run.model_run_data_rows())) - assert n_model_run_data_rows == 1 - - -@pytest.mark.parametrize('data_rows', [2], indirect=True) -def test_model_run_upsert_data_rows_using_global_keys(model_run, data_rows): - global_keys = [dr.global_key for dr in data_rows] - assert model_run.upsert_data_rows(global_keys=global_keys) - model_run_data_rows = list(model_run.model_run_data_rows()) - added_data_rows = [mdr.data_row() for mdr in model_run_data_rows] - assert set(added_data_rows) == set(data_rows) - - -def test_model_run_upsert_data_rows_with_existing_labels( - model_run_with_data_rows): - model_run_data_rows = list(model_run_with_data_rows.model_run_data_rows()) - n_data_rows = len(model_run_data_rows) - model_run_with_data_rows.upsert_data_rows([ - model_run_data_row.data_row().uid - for model_run_data_row in model_run_data_rows - ]) - assert n_data_rows == len( - list(model_run_with_data_rows.model_run_data_rows())) - - -def test_model_run_export_labels(model_run_with_data_rows): - labels = model_run_with_data_rows.export_labels(download=True) - assert len(labels) == 3 - - -@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem", - reason="does not work for onprem") -def test_model_run_status(model_run_with_data_rows): - - def get_model_run_status(): - return model_run_with_data_rows.client.execute( - """query trainingPipelinePyApi($modelRunId: ID!) { - trainingPipeline(where: {id : $modelRunId}) {status, errorMessage, metadata}} - """, {'modelRunId': model_run_with_data_rows.uid}, - experimental=True)['trainingPipeline'] - - model_run_status = get_model_run_status() - assert model_run_status['status'] is None - assert model_run_status['metadata'] is None - assert model_run_status['errorMessage'] is None - - status = "COMPLETE" - metadata = {'key1': 'value1'} - errorMessage = "an error" - model_run_with_data_rows.update_status(status, metadata, errorMessage) - - model_run_status = get_model_run_status() - assert model_run_status['status'] == status - assert model_run_status['metadata'] == metadata - assert model_run_status['errorMessage'] == errorMessage - - extra_metadata = {'key2': 'value2'} - model_run_with_data_rows.update_status(status, extra_metadata) - model_run_status = get_model_run_status() - assert model_run_status['status'] == status - assert model_run_status['metadata'] == {**metadata, **extra_metadata} - assert model_run_status['errorMessage'] == errorMessage - - status = ModelRun.Status.FAILED - model_run_with_data_rows.update_status(status, metadata, errorMessage) - model_run_status = get_model_run_status() - assert model_run_status['status'] == status.value - - with pytest.raises(ValueError): - model_run_with_data_rows.update_status("INVALID", metadata, - errorMessage) - - -def test_model_run_split_assignment_by_data_row_ids( - model_run, dataset, image_url, configured_project_with_one_data_row): - n_data_rows = 2 - data_rows = dataset.create_data_rows([{ - "row_data": image_url - } for _ in range(n_data_rows)]) - data_row_ids = [data_row['id'] for data_row in data_rows.result] - configured_project_with_one_data_row._wait_until_data_rows_are_processed( - data_row_ids=data_row_ids) - model_run.upsert_data_rows(data_row_ids) - - with pytest.raises(ValueError): - model_run.assign_data_rows_to_split(data_row_ids, 
"INVALID SPLIT") - - for split in ["TRAINING", "TEST", "VALIDATION", "UNASSIGNED", *DataSplit]: - model_run.assign_data_rows_to_split(data_row_ids, split) - counts = Counter() - for data_row in model_run.model_run_data_rows(): - counts[data_row.data_split.value] += 1 - split = split.value if isinstance(split, DataSplit) else split - assert counts[split] == n_data_rows - - -@pytest.mark.parametrize('data_rows', [2], indirect=True) -def test_model_run_split_assignment_by_global_keys(model_run, data_rows): - global_keys = [data_row.global_key for data_row in data_rows] - - model_run.upsert_data_rows(global_keys=global_keys) - - for split in ["TRAINING", "TEST", "VALIDATION", "UNASSIGNED", *DataSplit]: - model_run.assign_data_rows_to_split(split=split, - global_keys=global_keys) - splits = [ - data_row.data_split.value - for data_row in model_run.model_run_data_rows() - ] - assert len(set(splits)) == 1 - ----- -tests/integration/annotation_import/test_label_import.py -import uuid -import pytest - -from labelbox.schema.annotation_import import AnnotationImportState, LabelImport -""" -- Here we only want to check that the uploads are calling the validation -- Then with unit tests we can check the types of errors raised - -""" - - -def test_create_from_url(client, configured_project_with_one_data_row, - annotation_import_test_helpers): - name = str(uuid.uuid4()) - url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - label_import = LabelImport.create_from_url( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=name, - url=url) - assert label_import.parent_id == configured_project_with_one_data_row.uid - annotation_import_test_helpers.check_running_state(label_import, name, url) - - -def test_create_from_objects(client, configured_project, object_predictions, - annotation_import_test_helpers): - """this test should check running state only to validate running, not completed""" - name = str(uuid.uuid4()) - - label_import = LabelImport.create_from_objects( - client=client, - project_id=configured_project.uid, - name=name, - labels=object_predictions) - - assert label_import.parent_id == configured_project.uid - annotation_import_test_helpers.check_running_state(label_import, name) - annotation_import_test_helpers.assert_file_content( - label_import.input_file_url, object_predictions) - - -# TODO: add me when we add this ability -# def test_create_from_local_file(client, tmp_path, project, -# object_predictions, annotation_import_test_helpers): -# name = str(uuid.uuid4()) -# file_name = f"{name}.ndjson" -# file_path = tmp_path / file_name -# with file_path.open("w") as f: -# ndjson.dump(object_predictions, f) - -# label_import = LabelImport.create_from_url(client=client, project_id=project.uid, name=name, url=str(file_path)) - -# assert label_import.parent_id == project.uid -# annotation_import_test_helpers.check_running_state(label_import, name) -# annotation_import_test_helpers.assert_file_content(label_import.input_file_url, object_predictions) - - -def test_get(client, configured_project_with_one_data_row, - annotation_import_test_helpers): - name = str(uuid.uuid4()) - url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - label_import = LabelImport.create_from_url( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=name, - url=url) - - assert label_import.parent_id == configured_project_with_one_data_row.uid - 
annotation_import_test_helpers.check_running_state(label_import, name, url) - - -@pytest.mark.slow -def test_wait_till_done(client, configured_project, predictions): - name = str(uuid.uuid4()) - label_import = LabelImport.create_from_objects( - client=client, - project_id=configured_project.uid, - name=name, - labels=predictions) - - assert len(label_import.inputs) == len(predictions) - label_import.wait_until_done() - - assert label_import.state == AnnotationImportState.FINISHED - assert len(label_import.inputs) == len(predictions) - input_uuids = [input_annot['uuid'] for input_annot in label_import.inputs] - inference_uuids = [pred['uuid'] for pred in predictions] - assert set(input_uuids) == set(inference_uuids) - assert len(label_import.statuses) == len(predictions) - status_uuids = [ - input_annot['uuid'] for input_annot in label_import.statuses - ] - assert set(input_uuids) == set(status_uuids) - ----- -tests/integration/annotation_import/test_bulk_import_request.py -from unittest.mock import patch -import uuid -from labelbox import parser -import pytest -import random -from labelbox.data.annotation_types.annotation import ObjectAnnotation -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnnotation, ClassificationAnswer, Radio -from labelbox.data.annotation_types.data.video import VideoData -from labelbox.data.annotation_types.geometry.point import Point -from labelbox.data.annotation_types.geometry.rectangle import Rectangle, RectangleUnit -from labelbox.data.annotation_types.label import Label -from labelbox.data.annotation_types.data.text import TextData -from labelbox.data.annotation_types.ner import DocumentEntity, DocumentTextSelection -from labelbox.data.annotation_types.video import VideoObjectAnnotation - -from labelbox.data.serialization import NDJsonConverter -from labelbox.exceptions import MALValidationError, UuidError -from labelbox.schema.bulk_import_request import BulkImportRequest -from labelbox.schema.enums import BulkImportRequestState -from labelbox.schema.annotation_import import LabelImport, MALPredictionImport -from labelbox.schema.media_type import MediaType -""" -- Here we only want to check that the uploads are calling the validation -- Then with unit tests we can check the types of errors raised - -""" - - -def test_create_from_url(project): - name = str(uuid.uuid4()) - url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - - bulk_import_request = project.upload_annotations(name=name, - annotations=url, - validate=False) - - assert bulk_import_request.project() == project - assert bulk_import_request.name == name - assert bulk_import_request.input_file_url == url - assert bulk_import_request.error_file_url is None - assert bulk_import_request.status_file_url is None - assert bulk_import_request.state == BulkImportRequestState.RUNNING - - -def test_validate_file(project_with_empty_ontology): - name = str(uuid.uuid4()) - url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - with pytest.raises(MALValidationError): - project_with_empty_ontology.upload_annotations(name=name, - annotations=url, - validate=True) - #Schema ids shouldn't match - - -def test_create_from_objects(configured_project_with_one_data_row, predictions, - annotation_import_test_helpers): - name = str(uuid.uuid4()) - - bulk_import_request = configured_project_with_one_data_row.upload_annotations( - name=name, annotations=predictions) - - assert 
bulk_import_request.project() == configured_project_with_one_data_row - assert bulk_import_request.name == name - assert bulk_import_request.error_file_url is None - assert bulk_import_request.status_file_url is None - assert bulk_import_request.state == BulkImportRequestState.RUNNING - annotation_import_test_helpers.assert_file_content( - bulk_import_request.input_file_url, predictions) - - -def test_create_from_label_objects(configured_project, predictions, - annotation_import_test_helpers): - name = str(uuid.uuid4()) - - labels = list(NDJsonConverter.deserialize(predictions)) - bulk_import_request = configured_project.upload_annotations( - name=name, annotations=labels) - - assert bulk_import_request.project() == configured_project - assert bulk_import_request.name == name - assert bulk_import_request.error_file_url is None - assert bulk_import_request.status_file_url is None - assert bulk_import_request.state == BulkImportRequestState.RUNNING - normalized_predictions = list(NDJsonConverter.serialize(labels)) - annotation_import_test_helpers.assert_file_content( - bulk_import_request.input_file_url, normalized_predictions) - - -def test_create_from_local_file(tmp_path, predictions, configured_project, - annotation_import_test_helpers): - name = str(uuid.uuid4()) - file_name = f"{name}.ndjson" - file_path = tmp_path / file_name - with file_path.open("w") as f: - parser.dump(predictions, f) - - bulk_import_request = configured_project.upload_annotations( - name=name, annotations=str(file_path), validate=False) - - assert bulk_import_request.project() == configured_project - assert bulk_import_request.name == name - assert bulk_import_request.error_file_url is None - assert bulk_import_request.status_file_url is None - assert bulk_import_request.state == BulkImportRequestState.RUNNING - annotation_import_test_helpers.assert_file_content( - bulk_import_request.input_file_url, predictions) - - -def test_get(client, configured_project_with_one_data_row): - name = str(uuid.uuid4()) - url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - configured_project_with_one_data_row.upload_annotations(name=name, - annotations=url, - validate=False) - - bulk_import_request = BulkImportRequest.from_name( - client, project_id=configured_project_with_one_data_row.uid, name=name) - - assert bulk_import_request.project() == configured_project_with_one_data_row - assert bulk_import_request.name == name - assert bulk_import_request.input_file_url == url - assert bulk_import_request.error_file_url is None - assert bulk_import_request.status_file_url is None - assert bulk_import_request.state == BulkImportRequestState.RUNNING - - -def test_validate_ndjson(tmp_path, configured_project_with_one_data_row): - file_name = f"broken.ndjson" - file_path = tmp_path / file_name - with file_path.open("w") as f: - f.write("test") - - with pytest.raises(ValueError): - configured_project_with_one_data_row.upload_annotations( - name="name", validate=True, annotations=str(file_path)) - - -def test_validate_ndjson_uuid(tmp_path, configured_project, predictions): - file_name = f"repeat_uuid.ndjson" - file_path = tmp_path / file_name - repeat_uuid = predictions.copy() - uid = str(uuid.uuid4()) - repeat_uuid[0]['uuid'] = uid - repeat_uuid[1]['uuid'] = uid - - with file_path.open("w") as f: - parser.dump(repeat_uuid, f) - - with pytest.raises(UuidError): - configured_project.upload_annotations(name="name", - validate=True, - annotations=str(file_path)) - - with pytest.raises(UuidError): - 
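-        # The same duplicated-uuid predictions, passed as in-memory objects rather than an ndjson file, should fail validation the same way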
configured_project.upload_annotations(name="name", - validate=True, - annotations=repeat_uuid) - - -@pytest.mark.slow -def test_wait_till_done(rectangle_inference, - configured_project_with_one_data_row): - name = str(uuid.uuid4()) - url = configured_project_with_one_data_row.client.upload_data( - content=parser.dumps([rectangle_inference]), sign=True) - bulk_import_request = configured_project_with_one_data_row.upload_annotations( - name=name, annotations=url, validate=False) - - assert len(bulk_import_request.inputs) == 1 - bulk_import_request.wait_until_done() - assert bulk_import_request.state == BulkImportRequestState.FINISHED - - # Check that the status files are being returned as expected - assert len(bulk_import_request.errors) == 0 - assert len(bulk_import_request.inputs) == 1 - assert bulk_import_request.inputs[0]['uuid'] == rectangle_inference['uuid'] - assert len(bulk_import_request.statuses) == 1 - assert bulk_import_request.statuses[0]['status'] == 'SUCCESS' - assert bulk_import_request.statuses[0]['uuid'] == rectangle_inference[ - 'uuid'] - - -def test_project_bulk_import_requests(configured_project, predictions): - result = configured_project.bulk_import_requests() - assert len(list(result)) == 0 - - name = str(uuid.uuid4()) - bulk_import_request = configured_project.upload_annotations( - name=name, annotations=predictions) - bulk_import_request.wait_until_done() - - name = str(uuid.uuid4()) - bulk_import_request = configured_project.upload_annotations( - name=name, annotations=predictions) - bulk_import_request.wait_until_done() - - name = str(uuid.uuid4()) - bulk_import_request = configured_project.upload_annotations( - name=name, annotations=predictions) - bulk_import_request.wait_until_done() - - result = configured_project.bulk_import_requests() - assert len(list(result)) == 3 - - -def test_delete(configured_project, predictions): - name = str(uuid.uuid4()) - - bulk_import_request = configured_project.upload_annotations( - name=name, annotations=predictions) - bulk_import_request.wait_until_done() - all_import_requests = configured_project.bulk_import_requests() - assert len(list(all_import_requests)) == 1 - - bulk_import_request.delete() - all_import_requests = configured_project.bulk_import_requests() - assert len(list(all_import_requests)) == 0 - - -def test_pdf_mal_bbox(client, configured_project_pdf): - """ - tests pdf mal against only a bbox annotation - """ - annotations = [] - num_annotations = 1 - - for row in configured_project_pdf.export_queued_data_rows(): - for _ in range(num_annotations): - annotations.append({ - "uuid": str(uuid.uuid4()), - "name": "bbox", - "dataRow": { - "id": row['id'] - }, - "bbox": { - "top": round(random.uniform(0, 300), 2), - "left": round(random.uniform(0, 300), 2), - "height": round(random.uniform(200, 500), 2), - "width": round(random.uniform(0, 200), 2) - }, - "page": random.randint(0, 1), - "unit": "POINTS" - }) - annotations.extend([ - { #annotations intended to test classifications - 'name': 'text', - 'answer': 'the answer to the text question', - 'uuid': 'fc1913c6-b735-4dea-bd25-c18152a4715f', - "dataRow": { - "id": row['id'] - } - }, - { - 'name': 'checklist', - 'uuid': '9d7b2e57-d68f-4388-867a-af2a9b233719', - "dataRow": { - "id": row['id'] - }, - 'answer': [{ - 'name': 'option1' - }, { - 'name': 'optionN' - }] - }, - { - 'name': 'radio', - 'answer': { - 'name': 'second_radio_answer' - }, - 'uuid': 'ad60897f-ea1a-47de-b923-459339764921', - "dataRow": { - "id": row['id'] - } - }, - { #adding this with the intention to 
ensure we allow page: 0 - "uuid": str(uuid.uuid4()), - "name": "bbox", - "dataRow": { - "id": row['id'] - }, - "bbox": { - "top": round(random.uniform(0, 300), 2), - "left": round(random.uniform(0, 300), 2), - "height": round(random.uniform(200, 500), 2), - "width": round(random.uniform(0, 200), 2) - }, - "page": 0, - "unit": "POINTS" - } - ]) - import_annotations = MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_pdf.uid, - name=f"import {str(uuid.uuid4())}", - predictions=annotations) - import_annotations.wait_until_done() - - assert import_annotations.errors == [] - - -def test_pdf_document_entity(client, configured_project_with_one_data_row, - dataset_pdf_entity, rand_gen): - # for content "Metal-insulator (MI) transitions have been one of the" in OCR JSON extract tests/assets/arxiv-pdf_data_99-word-token-pdfs_0801.3483-lb-textlayer.json - document_text_selection = DocumentTextSelection( - group_id="2f4336f4-a07e-4e0a-a9e1-5629b03b719b", - token_ids=[ - "3f984bf3-1d61-44f5-b59a-9658a2e3440f", - "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", - "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", - "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", - "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", - "67c7c19e-4654-425d-bf17-2adb8cf02c30", - "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", - "b0e94071-2187-461e-8e76-96c58738a52c" - ], - page=1) - - entities_annotation_document_entity = DocumentEntity( - text_selections=[document_text_selection]) - entities_annotation = ObjectAnnotation( - name="named-entity", value=entities_annotation_document_entity) - - labels = [] - _, data_row_uids = dataset_pdf_entity - configured_project_with_one_data_row.create_batch( - rand_gen(str), - data_row_uids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - - for data_row_uid in data_row_uids: - labels.append( - Label(data=TextData(uid=data_row_uid), - annotations=[ - entities_annotation, - ])) - - import_annotations = MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=f"import {str(uuid.uuid4())}", - predictions=labels) - import_annotations.wait_until_done() - - assert import_annotations.errors == [] - - -def test_nested_video_object_annotations(client, - configured_project_with_one_data_row, - video_data, - bbox_video_annotation_objects, - rand_gen): - labels = [] - _, data_row_uids = video_data - configured_project_with_one_data_row.update(media_type=MediaType.Video) - configured_project_with_one_data_row.create_batch( - rand_gen(str), - data_row_uids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - - for data_row_uid in data_row_uids: - labels.append( - Label(data=VideoData(uid=data_row_uid), - annotations=bbox_video_annotation_objects)) - import_annotations = MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=f"import {str(uuid.uuid4())}", - predictions=labels) - import_annotations.wait_until_done() - - assert import_annotations.errors == [] - - -def _create_label(row_index, data_row_uids, label_name_ids=['bbox']): - label_name = label_name_ids[row_index % len(label_name_ids)] - data_row_uid = data_row_uids[row_index % len(data_row_uids)] - return Label(data=VideoData(uid=data_row_uid), - annotations=[ - VideoObjectAnnotation(name=label_name, - keyframe=True, - frame=4, - segment_index=0, - value=Rectangle( - start=Point(x=100, y=100), - end=Point(x=105, y=105), - )) - ]) - - 
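-# A minimal offline sanity sketch for the `_create_label` helper above. It only
-# builds annotation objects locally (no client calls), and the data row uid is a
-# placeholder rather than a real row, so it is safe to run anywhere.
-def test_create_label_helper_shape():
-    label = _create_label(0, ["placeholder-data-row-uid"])
-    assert label.data.uid == "placeholder-data-row-uid"
-    annotation = label.annotations[0]
-    # The helper always emits a single keyframe bounding box at frame 4
-    assert annotation.name == 'bbox'
-    assert annotation.keyframe is True
-    assert annotation.frame == 4
-    assert (annotation.value.start.x, annotation.value.start.y) == (100, 100)
-    assert (annotation.value.end.x, annotation.value.end.y) == (105, 105)
-
-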
-@patch('labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT', 20) -def test_below_annotation_limit_on_single_data_row( - client, configured_project_with_one_data_row, video_data, rand_gen): - _, data_row_uids = video_data - configured_project_with_one_data_row.update(media_type=MediaType.Video) - configured_project_with_one_data_row.create_batch( - rand_gen(str), - data_row_uids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - labels = [_create_label(index, data_row_uids) for index in range(19)] - import_annotations = MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=f"import {str(uuid.uuid4())}", - predictions=labels) - import_annotations.wait_until_done() - - assert import_annotations.errors == [] - - -@patch('labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT', 20) -def test_above_annotation_limit_on_single_label_on_single_data_row( - client, configured_project_with_one_data_row, video_data, rand_gen): - _, data_row_uids = video_data - - configured_project_with_one_data_row.update(media_type=MediaType.Video) - configured_project_with_one_data_row.create_batch( - rand_gen(str), - data_row_uids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - labels = [_create_label(index, data_row_uids) for index in range(21)] - with pytest.raises(ValueError): - import_annotations = MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=f"import {str(uuid.uuid4())}", - predictions=labels) - import_annotations.wait_until_done() - - -@patch('labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT', 20) -def test_above_annotation_limit_divided_among_different_rows( - client, configured_project_with_one_data_row, video_data_100_rows, - rand_gen): - _, data_row_uids = video_data_100_rows - - configured_project_with_one_data_row.update(media_type=MediaType.Video) - configured_project_with_one_data_row.create_batch( - rand_gen(str), - data_row_uids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - labels = [_create_label(index, data_row_uids) for index in range(21)] - - import_annotations = MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=f"import {str(uuid.uuid4())}", - predictions=labels) - - assert import_annotations.errors == [] - - -@patch('labelbox.schema.annotation_import.ANNOTATION_PER_LABEL_LIMIT', 20) -def test_above_annotation_limit_divided_among_labels_on_one_row( - client, configured_project_with_one_data_row, video_data, rand_gen): - _, data_row_uids = video_data - - configured_project_with_one_data_row.update(media_type=MediaType.Video) - configured_project_with_one_data_row.create_batch( - rand_gen(str), - data_row_uids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - labels = [ - _create_label(index, - data_row_uids, - label_name_ids=['bbox', 'bbox_tool_with_nested_text']) - for index in range(21) - ] - - import_annotations = MALPredictionImport.create_from_objects( - client=client, - project_id=configured_project_with_one_data_row.uid, - name=f"import {str(uuid.uuid4())}", - predictions=labels) - - assert import_annotations.errors == [] - ----- -tests/integration/annotation_import/fixtures/__init__.py - ----- -tests/integration/annotation_import/fixtures/video_annotations.py -import pytest -from 
labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnnotation, ClassificationAnswer, Radio -from labelbox.data.annotation_types.geometry.point import Point -from labelbox.data.annotation_types.geometry.rectangle import Rectangle - -from labelbox.data.annotation_types.video import VideoObjectAnnotation - - -@pytest.fixture -def bbox_video_annotation_objects(): - bbox_annotation = [ - VideoObjectAnnotation( - name="bbox", - keyframe=True, - frame=13, - segment_index=0, - value=Rectangle( - start=Point(x=146.0, y=98.0), # Top left - end=Point(x=382.0, y=341.0), # Bottom right - ), - classifications=[ - ClassificationAnnotation( - name='nested', - value=Radio(answer=ClassificationAnswer( - name='radio_option_1', - classifications=[ - ClassificationAnnotation( - name='nested_checkbox', - value=Checklist(answer=[ - ClassificationAnswer( - name='nested_checkbox_option_1'), - ClassificationAnswer( - name='nested_checkbox_option_2') - ])) - ])), - ) - ]), - VideoObjectAnnotation( - name="bbox", - keyframe=True, - frame=19, - segment_index=0, - value=Rectangle( - start=Point(x=186.0, y=98.0), # Top left - end=Point(x=490.0, y=341.0), # Bottom right - )) - ] - - return bbox_annotation - ----- -tests/integration/annotation_import/fixtures/export_v2.py -import pytest - - -@pytest.fixture() -def expected_export_v2_image(): - exported_annotations = { - 'objects': [{ - 'name': - 'polygon', - 'value': - 'polygon', - 'annotation_kind': - 'ImagePolygon', - 'classifications': [], - 'polygon': [{ - 'x': 147.692, - 'y': 118.154 - }, { - 'x': 142.769, - 'y': 104.923 - }, { - 'x': 57.846, - 'y': 118.769 - }, { - 'x': 28.308, - 'y': 169.846 - }, { - 'x': 147.692, - 'y': 118.154 - }] - }, { - 'name': 'bbox', - 'value': 'bbox', - 'annotation_kind': 'ImageBoundingBox', - 'classifications': [{ - 'name': 'nested', - 'value': 'nested', - 'radio_answer': { - 'name': 'radio_option_1', - 'value': 'radio_value_1', - 'classifications': [] - } - }], - 'bounding_box': { - 'top': 48.0, - 'left': 58.0, - 'height': 65.0, - 'width': 12.0 - } - }, { - 'name': 'polyline', - 'value': 'polyline', - 'annotation_kind': 'ImagePolyline', - 'classifications': [], - 'line': [{ - 'x': 147.692, - 'y': 118.154 - }, { - 'x': 150.692, - 'y': 160.154 - }] - }], - 'classifications': [{ - 'name': - 'checklist', - 'value': - 'checklist', - 'checklist_answers': [{ - 'name': 'option1', - 'value': 'option1', - 'classifications': [] - }] - }, { - 'name': 'text', - 'value': 'text', - 'text_answer': { - 'content': 'free form text...' - } - }], - 'relationships': [] - } - - return exported_annotations - - -@pytest.fixture() -def expected_export_v2_audio(): - expected_annotations = { - 'objects': [], - 'classifications': [{ - 'name': - 'checklist', - 'value': - 'checklist', - 'checklist_answers': [{ - 'name': 'option1', - 'value': 'option1', - 'classifications': [] - }] - }, { - 'name': 'text', - 'value': 'text', - 'text_answer': { - 'content': 'free form text...' - } - }], - 'relationships': [] - } - return expected_annotations - - -@pytest.fixture() -def expected_export_v2_html(): - expected_annotations = { - 'objects': [], - 'classifications': [{ - 'name': 'text', - 'value': 'text', - 'text_answer': { - 'content': 'free form text...' 
- } - }, { - 'name': - 'checklist', - 'value': - 'checklist', - 'checklist_answers': [{ - 'name': 'option1', - 'value': 'option1', - 'classifications': [] - }] - }], - 'relationships': [] - } - return expected_annotations - - -@pytest.fixture() -def expected_export_v2_text(): - expected_annotations = { - 'objects': [{ - 'name': 'named-entity', - 'value': 'named_entity', - 'annotation_kind': 'TextEntity', - 'classifications': [], - 'location': { - 'start': - 66, - 'end': - 128, - 'token': - "more people to express themselves online😞😂‚, research suggests" - } - }], - 'classifications': [{ - 'name': - 'checklist', - 'value': - 'checklist', - 'checklist_answers': [{ - 'name': 'option1', - 'value': 'option1', - 'classifications': [] - }] - }, { - 'name': 'text', - 'value': 'text', - 'text_answer': { - 'content': 'free form text...' - } - }], - 'relationships': [] - } - return expected_annotations - - -@pytest.fixture() -def expected_export_v2_video(): - expected_annotations = { - 'frames': {}, - 'segments': { - '': [[7, 13], [18, 19]] - }, - 'key_frame_feature_map': {}, - 'classifications': [{ - 'name': - 'checklist', - 'value': - 'checklist', - 'checklist_answers': [{ - 'name': 'option1', - 'value': 'option1', - 'classifications': [] - }] - }] - } - return expected_annotations - - -@pytest.fixture() -def expected_export_v2_conversation(): - expected_annotations = { - 'objects': [{ - 'name': 'named-entity', - 'value': 'named_entity', - 'annotation_kind': 'ConversationalTextEntity', - 'classifications': [], - 'conversational_location': { - 'message_id': '0', - 'location': { - 'start': 0, - 'end': 8 - } - } - }], - 'classifications': [{ - 'name': - 'checklist_index', - 'value': - 'checklist_index', - 'message_id': - '0', - 'conversational_checklist_answers': [{ - 'name': 'option1_index', - 'value': 'option1_index', - 'classifications': [] - }] - }, { - 'name': 'text_index', - 'value': 'text_index', - 'message_id': '0', - 'conversational_text_answer': { - 'content': 'free form text...' 
- } - }], - 'relationships': [] - } - return expected_annotations - - -@pytest.fixture() -def expected_export_v2_dicom(): - expected_annotations = { - 'groups': { - 'Axial': { - 'name': 'Axial', - 'classifications': [], - 'frames': { - '1': { - 'objects': { - '': { - 'name': - 'polyline', - 'value': - 'polyline', - 'annotation_kind': - 'DICOMPolyline', - 'classifications': [], - 'line': [{ - 'x': 147.692, - 'y': 118.154 - }, { - 'x': 150.692, - 'y': 160.154 - }] - } - }, - 'classifications': [] - } - } - }, - 'Sagittal': { - 'name': 'Sagittal', - 'classifications': [], - 'frames': {} - }, - 'Coronal': { - 'name': 'Coronal', - 'classifications': [], - 'frames': {} - } - }, - 'segments': { - 'Axial': { - '': [[1, 1]] - }, - 'Sagittal': {}, - 'Coronal': {} - }, - 'classifications': [], - 'key_frame_feature_map': { - '': { - 'Axial': { - '1': True - }, - 'Coronal': {}, - 'Sagittal': {} - } - } - } - return expected_annotations - - -@pytest.fixture() -def expected_export_v2_document(): - expected_annotations = { - 'objects': [{ - 'name': 'named-entity', - 'value': 'named_entity', - 'annotation_kind': 'DocumentEntityToken', - 'classifications': [], - 'location': { - 'groups': [{ - 'id': - '2f4336f4-a07e-4e0a-a9e1-5629b03b719b', - 'page_number': - 1, - 'tokens': [ - '3f984bf3-1d61-44f5-b59a-9658a2e3440f', - '3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8', - '6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80', - '87a43d32-af76-4a1d-b262-5c5f4d5ace3a', - 'e8606e8a-dfd9-4c49-a635-ad5c879c75d0', - '67c7c19e-4654-425d-bf17-2adb8cf02c30', - '149c5e80-3e07-49a7-ab2d-29ddfe6a38fa', - 'b0e94071-2187-461e-8e76-96c58738a52c' - ], - 'text': - 'Metal-insulator (MI) transitions have been one of the' - }] - } - }, { - 'name': 'bbox', - 'value': 'bbox', - 'annotation_kind': 'DocumentBoundingBox', - 'classifications': [{ - 'name': 'nested', - 'value': 'nested', - 'radio_answer': { - 'name': 'radio_option_1', - 'value': 'radio_value_1', - 'classifications': [] - } - }], - 'page_number': 1, - 'bounding_box': { - 'top': 48.0, - 'left': 58.0, - 'height': 65.0, - 'width': 12.0 - } - }], - 'classifications': [{ - 'name': - 'checklist', - 'value': - 'checklist', - 'checklist_answers': [{ - 'name': 'option1', - 'value': 'option1', - 'classifications': [] - }] - }, { - 'name': 'text', - 'value': 'text', - 'text_answer': { - 'content': 'free form text...' - } - }], - 'relationships': [] - } - return expected_annotations - - -@pytest.fixture() -def expected_export_v2_llm_prompt_creation(): - expected_annotations = { - 'objects': [], - 'classifications': [{ - 'name': - 'checklist', - 'value': - 'checklist', - 'checklist_answers': [{ - 'name': 'option1', - 'value': 'option1', - 'classifications': [] - }] - }, { - 'name': 'text', - 'value': 'text', - 'text_answer': { - 'content': 'free form text...' - } - }], - 'relationships': [] - } - return expected_annotations - - -@pytest.fixture() -def expected_export_v2_llm_prompt_response_creation(): - expected_annotations = { - 'objects': [], - 'classifications': [{ - 'name': - 'checklist', - 'value': - 'checklist', - 'checklist_answers': [{ - 'name': 'option1', - 'value': 'option1', - 'classifications': [] - }] - }, { - 'name': 'text', - 'value': 'text', - 'text_answer': { - 'content': 'free form text...' 
-            }
-        }],
-        'relationships': []
-    }
-    return expected_annotations
-
-
-@pytest.fixture()
-def expected_export_v2_llm_response_creation():
-    expected_annotations = {
-        'objects': [],
-        'classifications': [{
-            'name': 'checklist',
-            'value': 'checklist',
-            'checklist_answers': [{
-                'name': 'option1',
-                'value': 'option1',
-                'classifications': []
-            }]
-        }, {
-            'name': 'text',
-            'value': 'text',
-            'text_answer': {
-                'content': 'free form text...'
-            }
-        }],
-        'relationships': []
-    }
-    return expected_annotations
-
-----
-tests/integration/annotation_import/fixtures/annotations.py
-import pytest
-from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnnotation, ClassificationAnswer, Radio
-from labelbox.data.annotation_types.geometry.point import Point
-from labelbox.data.annotation_types.geometry.rectangle import Rectangle
-
-from labelbox.data.annotation_types.video import VideoObjectAnnotation
-
-
-@pytest.fixture
-def bbox_video_annotation_objects():
-    bbox_annotation = [
-        VideoObjectAnnotation(
-            name="bbox",
-            keyframe=True,
-            frame=13,
-            segment_index=0,
-            value=Rectangle(
-                start=Point(x=146.0, y=98.0),  # Top left
-                end=Point(x=382.0, y=341.0),  # Bottom right
-            ),
-            classifications=[
-                ClassificationAnnotation(
-                    name='nested',
-                    value=Radio(answer=ClassificationAnswer(
-                        name='radio_option_1',
-                        classifications=[
-                            ClassificationAnnotation(
-                                name='nested_checkbox',
-                                value=Checklist(answer=[
-                                    ClassificationAnswer(
-                                        name='nested_checkbox_option_1'),
-                                    ClassificationAnswer(
-                                        name='nested_checkbox_option_2')
-                                ]))
-                        ])),
-                )
-            ]),
-        VideoObjectAnnotation(
-            name="bbox",
-            keyframe=True,
-            frame=19,
-            segment_index=0,
-            value=Rectangle(
-                start=Point(x=186.0, y=98.0),  # Top left
-                end=Point(x=490.0, y=341.0),  # Bottom right
-            ))
-    ]
-
-    return bbox_annotation
-
-----
-tests/integration/support/integration_client.py
-import os
-import re
-import uuid
-from enum import Enum
-from typing import Tuple
-
-import requests
-
-from labelbox import Client
-
-EPHEMERAL_BASE_URL = "http://lb-api-public"
-
-
-class Environ(Enum):
-    LOCAL = 'local'
-    PROD = 'prod'
-    STAGING = 'staging'
-    ONPREM = 'onprem'
-    CUSTOM = 'custom'
-    STAGING_EU = 'staging-eu'
-    EPHEMERAL = 'ephemeral'  # Used for testing PRs with ephemeral environments
-
-
-def ephemeral_endpoint() -> str:
-    return os.getenv('LABELBOX_TEST_BASE_URL', EPHEMERAL_BASE_URL)
-
-
-def graphql_url(environ: str) -> str:
-    if environ == Environ.PROD:
-        return 'https://api.labelbox.com/graphql'
-    elif environ == Environ.STAGING:
-        return 'https://api.lb-stage.xyz/graphql'
-    elif environ == Environ.STAGING_EU:
-        return 'https://api.eu-de.lb-stage.xyz/graphql'
-    elif environ == Environ.ONPREM:
-        hostname = os.environ.get('LABELBOX_TEST_ONPREM_HOSTNAME', None)
-        if hostname is None:
-            raise Exception("Missing LABELBOX_TEST_ONPREM_HOSTNAME")
-        return f"{hostname}/api/_gql"
-    elif environ == Environ.CUSTOM:
-        graphql_api_endpoint = os.environ.get(
-            'LABELBOX_TEST_GRAPHQL_API_ENDPOINT')
-        if graphql_api_endpoint is None:
-            raise Exception("Missing LABELBOX_TEST_GRAPHQL_API_ENDPOINT")
-        return graphql_api_endpoint
-    elif environ == Environ.EPHEMERAL:
-        return f"{ephemeral_endpoint()}/graphql"
-    return 'http://host.docker.internal:8080/graphql'
-
-
-def rest_url(environ: str) -> str:
-    if environ == Environ.PROD:
-        return 'https://api.labelbox.com/api/v1'
-    elif environ == Environ.STAGING:
-        return 'https://api.lb-stage.xyz/api/v1'
-    elif environ == Environ.STAGING_EU:
-        return
'https://api.eu-de.lb-stage.xyz/api/v1' - elif environ == Environ.CUSTOM: - rest_api_endpoint = os.environ.get('LABELBOX_TEST_REST_API_ENDPOINT') - if rest_api_endpoint is None: - raise Exception(f"Missing LABELBOX_TEST_REST_API_ENDPOINT") - return rest_api_endpoint - elif environ == Environ.EPHEMERAL: - return f"{ephemeral_endpoint()}/api/v1" - return 'http://host.docker.internal:8080/api/v1' - - -def testing_api_key(environ: str) -> str: - if environ == Environ.PROD: - return os.environ["LABELBOX_TEST_API_KEY_PROD"] - elif environ == Environ.STAGING: - return os.environ["LABELBOX_TEST_API_KEY_STAGING"] - elif environ == Environ.STAGING_EU: - return os.environ["LABELBOX_TEST_API_KEY_STAGING_EU"] - elif environ == Environ.ONPREM: - return os.environ["LABELBOX_TEST_API_KEY_ONPREM"] - elif environ == Environ.CUSTOM: - return os.environ["LABELBOX_TEST_API_KEY_CUSTOM"] - return os.environ["LABELBOX_TEST_API_KEY_LOCAL"] - - -def service_api_key() -> str: - return os.environ["SERVICE_API_KEY"] - - -class IntegrationClient(Client): - - def __init__(self, environ: str) -> None: - api_url = graphql_url(environ) - api_key = testing_api_key(environ) - rest_endpoint = rest_url(environ) - - super().__init__(api_key, - api_url, - enable_experimental=True, - rest_endpoint=rest_endpoint) - self.queries = [] - - def execute(self, query=None, params=None, check_naming=True, **kwargs): - if check_naming and query is not None: - assert re.match(r"\s*(?:query|mutation) \w+PyApi", - query) is not None - self.queries.append((query, params)) - return super().execute(query, params, **kwargs) - - -class AdminClient(Client): - - def __init__(self, env): - """ - The admin client creates organizations and users using admin api described here https://labelbox.atlassian.net/wiki/spaces/AP/pages/2206564433/Internal+Admin+APIs. 
-        """
-        self._api_key = service_api_key()
-        self._admin_endpoint = f"{ephemeral_endpoint()}/admin/v1"
-        self._api_url = graphql_url(env)
-        self._rest_endpoint = rest_url(env)
-
-        super().__init__(self._api_key,
-                         self._api_url,
-                         enable_experimental=True,
-                         rest_endpoint=self._rest_endpoint)
-
-    def _create_organization(self) -> str:
-        endpoint = f"{self._admin_endpoint}/organizations/"
-        response = requests.post(
-            endpoint,
-            headers=self.headers,
-            json={"name": f"Test Org {uuid.uuid4()}"},
-        )
-
-        data = response.json()
-        if response.status_code not in [
-                requests.codes.created, requests.codes.ok
-        ]:
-            raise Exception("Failed to create org, message: " +
-                            str(data['message']))
-
-        return data['id']
-
-    def _create_user(self, organization_id=None) -> Tuple[str, str]:
-        if organization_id is None:
-            organization_id = self.organization_id
-
-        endpoint = f"{self._admin_endpoint}/user-identities/"
-        identity_id = f"e2e+{uuid.uuid4()}"
-
-        response = requests.post(
-            endpoint,
-            headers=self.headers,
-            json={
-                "identityId": identity_id,
-                "email": "email@email.com",
-                "name": f"tester{uuid.uuid4()}",
-                "verificationStatus": "VERIFIED",
-            },
-        )
-        data = response.json()
-        if response.status_code not in [
-                requests.codes.created, requests.codes.ok
-        ]:
-            raise Exception("Failed to create user, message: " +
-                            str(data['message']))
-
-        user_identity_id = data['identityId']
-
-        endpoint = f"{self._admin_endpoint}/organizations/{organization_id}/users/"
-        response = requests.post(
-            endpoint,
-            headers=self.headers,
-            json={
-                "identityId": user_identity_id,
-                "organizationRole": "Admin"
-            },
-        )
-
-        data = response.json()
-        if response.status_code not in [
-                requests.codes.created, requests.codes.ok
-        ]:
-            raise Exception("Failed to link user to org, message: " +
-                            str(data['message']))
-
-        user_id = data['id']
-
-        endpoint = f"{self._admin_endpoint}/users/{user_id}/token"
-        response = requests.get(
-            endpoint,
-            headers=self.headers,
-        )
-        data = response.json()
-        if response.status_code not in [
-                requests.codes.created, requests.codes.ok
-        ]:
-            raise Exception("Failed to create ephemeral user, message: " +
-                            str(data['message']))
-
-        token = data["token"]
-
-        return user_id, token
-
-    def create_api_key_for_user(self) -> str:
-        organization_id = self._create_organization()
-        _, user_token = self._create_user(organization_id)
-        key_name = f"test-key+{uuid.uuid4()}"
-        query = """mutation CreateApiKeyPyApi($name: String!)
{ - createApiKey(data: {name: $name}) { - id - jwt - } - } - """ - params = {"name": key_name} - self.headers["Authorization"] = f"Bearer {user_token}" - res = self.execute(query, params, error_log_key="errors") - - return res["createApiKey"]["jwt"] - - -class EphemeralClient(Client): - - def __init__(self, environ=Environ.EPHEMERAL): - self.admin_client = AdminClient(environ) - self.api_key = self.admin_client.create_api_key_for_user() - api_url = graphql_url(environ) - rest_endpoint = rest_url(environ) - - super().__init__(self.api_key, - api_url, - enable_experimental=True, - rest_endpoint=rest_endpoint) - ----- -tests/integration/support/__init__.py - ----- -tests/integration/export/conftest.py -import uuid -import time -import pytest -from labelbox.schema.queue_mode import QueueMode -from labelbox.schema.media_type import MediaType -from labelbox.schema.labeling_frontend import LabelingFrontend -from labelbox.schema.annotation_import import LabelImport, AnnotationImportState - - -@pytest.fixture -def ontology(): - bbox_tool_with_nested_text = { - 'required': - False, - 'name': - 'bbox_tool_with_nested_text', - 'tool': - 'rectangle', - 'color': - '#a23030', - 'classifications': [{ - 'required': - False, - 'instructions': - 'nested', - 'name': - 'nested', - 'type': - 'radio', - 'options': [{ - 'label': - 'radio_option_1', - 'value': - 'radio_value_1', - 'options': [{ - 'required': - False, - 'instructions': - 'nested_checkbox', - 'name': - 'nested_checkbox', - 'type': - 'checklist', - 'options': [{ - 'label': 'nested_checkbox_option_1', - 'value': 'nested_checkbox_value_1', - 'options': [] - }, { - 'label': 'nested_checkbox_option_2', - 'value': 'nested_checkbox_value_2' - }] - }, { - 'required': False, - 'instructions': 'nested_text', - 'name': 'nested_text', - 'type': 'text', - 'options': [] - }] - },] - }] - } - - bbox_tool = { - 'required': - False, - 'name': - 'bbox', - 'tool': - 'rectangle', - 'color': - '#a23030', - 'classifications': [{ - 'required': - False, - 'instructions': - 'nested', - 'name': - 'nested', - 'type': - 'radio', - 'options': [{ - 'label': - 'radio_option_1', - 'value': - 'radio_value_1', - 'options': [{ - 'required': - False, - 'instructions': - 'nested_checkbox', - 'name': - 'nested_checkbox', - 'type': - 'checklist', - 'options': [{ - 'label': 'nested_checkbox_option_1', - 'value': 'nested_checkbox_value_1', - 'options': [] - }, { - 'label': 'nested_checkbox_option_2', - 'value': 'nested_checkbox_value_2' - }] - }] - },] - }] - } - - polygon_tool = { - 'required': False, - 'name': 'polygon', - 'tool': 'polygon', - 'color': '#FF34FF', - 'classifications': [] - } - polyline_tool = { - 'required': False, - 'name': 'polyline', - 'tool': 'line', - 'color': '#FF4A46', - 'classifications': [] - } - point_tool = { - 'required': False, - 'name': 'point--', - 'tool': 'point', - 'color': '#008941', - 'classifications': [] - } - entity_tool = { - 'required': False, - 'name': 'entity--', - 'tool': 'named-entity', - 'color': '#006FA6', - 'classifications': [] - } - segmentation_tool = { - 'required': False, - 'name': 'segmentation--', - 'tool': 'superpixel', - 'color': '#A30059', - 'classifications': [] - } - raster_segmentation_tool = { - 'required': False, - 'name': 'segmentation_mask', - 'tool': 'raster-segmentation', - 'color': '#ff0000', - 'classifications': [] - } - checklist = { - 'required': - False, - 'instructions': - 'checklist', - 'name': - 'checklist', - 'type': - 'checklist', - 'options': [{ - 'label': 'option1', - 'value': 'option1' - }, { - 'label': 
'option2', - 'value': 'option2' - }, { - 'label': 'optionN', - 'value': 'optionn' - }] - } - checklist_index = { - 'required': - False, - 'instructions': - 'checklist_index', - 'name': - 'checklist_index', - 'type': - 'checklist', - 'scope': - 'index', - 'options': [{ - 'label': 'option1_index', - 'value': 'option1_index' - }, { - 'label': 'option2_index', - 'value': 'option2_index' - }, { - 'label': 'optionN_index', - 'value': 'optionn_index' - }] - } - free_form_text = { - 'required': False, - 'instructions': 'text', - 'name': 'text', - 'type': 'text', - 'options': [] - } - free_form_text_index = { - 'required': False, - 'instructions': 'text_index', - 'name': 'text_index', - 'type': 'text', - 'scope': 'index', - 'options': [] - } - radio = { - 'required': - False, - 'instructions': - 'radio', - 'name': - 'radio', - 'type': - 'radio', - 'options': [{ - 'label': 'first_radio_answer', - 'value': 'first_radio_answer', - 'options': [] - }, { - 'label': 'second_radio_answer', - 'value': 'second_radio_answer', - 'options': [] - }] - } - named_entity = { - 'tool': 'named-entity', - 'name': 'named-entity', - 'required': False, - 'color': '#A30059', - 'classifications': [], - } - - tools = [ - bbox_tool, - bbox_tool_with_nested_text, - polygon_tool, - polyline_tool, - point_tool, - entity_tool, - segmentation_tool, - raster_segmentation_tool, - named_entity, - ] - classifications = [ - checklist, checklist_index, free_form_text, free_form_text_index, radio - ] - return {"tools": tools, "classifications": classifications} - - -@pytest.fixture -def polygon_inference(prediction_id_mapping): - polygon = prediction_id_mapping['polygon'].copy() - polygon.update({ - "polygon": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 142.769, - "y": 104.923 - }, { - "x": 57.846, - "y": 118.769 - }, { - "x": 28.308, - "y": 169.846 - }] - }) - del polygon['tool'] - return polygon - - -@pytest.fixture -def configured_project_with_ontology(client, initial_dataset, ontology, - rand_gen, image_url): - dataset = initial_dataset - project = client.create_project( - name=rand_gen(str), - queue_mode=QueueMode.Batch, - ) - editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - project.setup(editor, ontology) - data_row_ids = [] - - for _ in range(len(ontology['tools']) + len(ontology['classifications'])): - data_row_ids.append(dataset.create_data_row(row_data=image_url).uid) - project.create_batch( - rand_gen(str), - data_row_ids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - project.data_row_ids = data_row_ids - yield project - project.delete() - - -@pytest.fixture -def configured_project_without_data_rows(client, ontology, rand_gen): - project = client.create_project(name=rand_gen(str), - description=rand_gen(str), - queue_mode=QueueMode.Batch) - editor = list( - client.get_labeling_frontends( - where=LabelingFrontend.name == "editor"))[0] - project.setup(editor, ontology) - yield project - project.delete() - - -@pytest.fixture -def model_run_with_data_rows(client, configured_project_with_ontology, - model_run_predictions, model_run, - wait_for_label_processing): - configured_project_with_ontology.enable_model_assisted_labeling() - use_data_row_ids = [p['dataRow']['id'] for p in model_run_predictions] - model_run.upsert_data_rows(use_data_row_ids) - - upload_task = LabelImport.create_from_objects( - client, configured_project_with_ontology.uid, - f"label-import-{uuid.uuid4()}", model_run_predictions) - upload_task.wait_until_done() - assert 
upload_task.state == AnnotationImportState.FINISHED, "Label Import did not finish" - assert len( - upload_task.errors - ) == 0, f"Label Import {upload_task.name} failed with errors {upload_task.errors}" - labels = wait_for_label_processing(configured_project_with_ontology) - label_ids = [label.uid for label in labels] - model_run.upsert_labels(label_ids) - yield model_run, labels - model_run.delete() - # TODO: Delete resources when that is possible .. - - -@pytest.fixture -def model_run_predictions(polygon_inference, rectangle_inference, - line_inference): - # Not supporting mask since there isn't a signed url representing a seg mask to upload - return [polygon_inference, rectangle_inference, line_inference] - - -@pytest.fixture -def model(client, rand_gen, configured_project): - ontology = configured_project.ontology() - data = {"name": rand_gen(str), "ontology_id": ontology.uid} - model = client.create_model(data["name"], data["ontology_id"]) - yield model - try: - model.delete() - except: - # Already was deleted by the test - pass - - -@pytest.fixture -def model_run(rand_gen, model): - name = rand_gen(str) - model_run = model.create_model_run(name) - yield model_run - try: - model_run.delete() - except: - # Already was deleted by the test - pass - - -@pytest.fixture -def wait_for_label_processing(): - """ - Do not use. Only for testing. - - Returns project's labels as a list after waiting for them to finish processing. - If `project.labels()` is called before label is fully processed, - it may return an empty set - """ - - def func(project): - timeout_seconds = 10 - while True: - labels = list(project.labels()) - if len(labels) > 0: - return labels - timeout_seconds -= 2 - if timeout_seconds <= 0: - raise TimeoutError( - f"Timed out waiting for label for project '{project.uid}' to finish processing" - ) - time.sleep(2) - - return func - - -@pytest.fixture -def prediction_id_mapping(configured_project_with_ontology): - # Maps tool types to feature schema ids - project = configured_project_with_ontology - ontology = project.ontology().normalized - result = {} - - for idx, tool in enumerate(ontology['tools'] + ontology['classifications']): - if 'tool' in tool: - tool_type = tool['tool'] - else: - tool_type = tool[ - 'type'] if 'scope' not in tool else f"{tool['type']}_{tool['scope']}" # so 'checklist' of 'checklist_index' - - # TODO: remove this once we have a better way to associate multiple tools instances with a single tool type - if tool_type == 'rectangle': - value = { - "uuid": str(uuid.uuid4()), - "schemaId": tool['featureSchemaId'], - "name": tool['name'], - "dataRow": { - "id": project.data_row_ids[idx], - }, - 'tool': tool - } - if tool_type not in result: - result[tool_type] = [] - result[tool_type].append(value) - else: - result[tool_type] = { - "uuid": str(uuid.uuid4()), - "schemaId": tool['featureSchemaId'], - "name": tool['name'], - "dataRow": { - "id": project.data_row_ids[idx], - }, - 'tool': tool - } - return result - - -@pytest.fixture -def line_inference(prediction_id_mapping): - line = prediction_id_mapping['line'].copy() - line.update( - {"line": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 150.692, - "y": 160.154 - }]}) - del line['tool'] - return line - - -@pytest.fixture -def polygon_inference(prediction_id_mapping): - polygon = prediction_id_mapping['polygon'].copy() - polygon.update({ - "polygon": [{ - "x": 147.692, - "y": 118.154 - }, { - "x": 142.769, - "y": 104.923 - }, { - "x": 57.846, - "y": 118.769 - }, { - "x": 28.308, - "y": 169.846 - }] - }) - del 
polygon['tool'] - return polygon - - -def find_tool_by_name(tool_instances, name): - for tool in tool_instances: - if tool['name'] == name: - return tool - return None - - -@pytest.fixture -def rectangle_inference(prediction_id_mapping): - tool_instance = find_tool_by_name(prediction_id_mapping['rectangle'], - 'bbox') - rectangle = tool_instance.copy() - rectangle.update({ - "bbox": { - "top": 48, - "left": 58, - "height": 65, - "width": 12 - }, - 'classifications': [{ - "schemaId": - rectangle['tool']['classifications'][0]['featureSchemaId'], - "name": - rectangle['tool']['classifications'][0]['name'], - "answer": { - "schemaId": - rectangle['tool']['classifications'][0]['options'][0] - ['featureSchemaId'], - "name": - rectangle['tool']['classifications'][0]['options'][0] - ['value'] - } - }] - }) - del rectangle['tool'] - return rectangle - ----- -tests/integration/export/streamable/test_export_data_rows_streamable.py -import json -import time - -import pytest - -from labelbox import DataRow, ExportTask, StreamType - - -class TestExportDataRow: - - def test_with_data_row_object(self, client, data_row, - wait_for_data_row_processing): - data_row = wait_for_data_row_processing(client, data_row) - time.sleep(7) # temp fix for ES indexing delay - export_task = DataRow.export( - client=client, - data_rows=[data_row], - task_name="TestExportDataRow:test_with_data_row_object", - ) - export_task.wait_till_done() - assert export_task.status == "COMPLETE" - assert isinstance(export_task, ExportTask) - assert export_task.has_result() - assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1 - assert (json.loads(list(export_task.get_stream())[0].json_str) - ["data_row"]["id"] == data_row.uid) - - def test_with_id(self, client, data_row, wait_for_data_row_processing): - data_row = wait_for_data_row_processing(client, data_row) - time.sleep(7) # temp fix for ES indexing delay - export_task = DataRow.export(client=client, - data_rows=[data_row.uid], - task_name="TestExportDataRow:test_with_id") - export_task.wait_till_done() - assert export_task.status == "COMPLETE" - assert isinstance(export_task, ExportTask) - assert export_task.has_result() - assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1 - assert (json.loads(list(export_task.get_stream())[0].json_str) - ["data_row"]["id"] == data_row.uid) - - def test_with_global_key(self, client, data_row, - wait_for_data_row_processing): - data_row = wait_for_data_row_processing(client, data_row) - time.sleep(7) # temp fix for ES indexing delay - export_task = DataRow.export( - client=client, - global_keys=[data_row.global_key], - task_name="TestExportDataRow:test_with_global_key", - ) - export_task.wait_till_done() - assert export_task.status == "COMPLETE" - assert isinstance(export_task, ExportTask) - assert export_task.has_result() - assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1 - assert (json.loads(list(export_task.get_stream())[0].json_str) - ["data_row"]["id"] == data_row.uid) - - def test_with_invalid_id(self, client): - export_task = DataRow.export( - client=client, - data_rows=["invalid_id"], - 
task_name="TestExportDataRow:test_with_invalid_id", - ) - export_task.wait_till_done() - assert export_task.status == "COMPLETE" - assert isinstance(export_task, ExportTask) - assert export_task.has_result() is False - assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) is None - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) is None - ----- -tests/integration/export/streamable/test_export_model_run_streamable.py -import json -import time - -from labelbox import ExportTask, StreamType - - -class TestExportModelRun: - - def test_export(self, model_run_with_data_rows): - model_run, labels = model_run_with_data_rows - label_ids = [label.uid for label in labels] - expected_data_rows = list(model_run.model_run_data_rows()) - - task_name = "TestExportModelRun:test_export" - params = {"media_attributes": True, "predictions": True} - export_task = model_run.export(task_name, params=params) - assert export_task.name == task_name - export_task.wait_till_done() - - assert export_task.status == "COMPLETE" - assert isinstance(export_task, ExportTask) - assert export_task.has_result() - assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == len(expected_data_rows) - - for data in export_task.get_stream(): - obj = json.loads(data.json_str) - assert "media_attributes" in obj and obj[ - "media_attributes"] is not None - exported_model_run = obj["experiments"][model_run.model_id]["runs"][ - model_run.uid] - task_label_ids_set = set( - map(lambda label: label["id"], exported_model_run["labels"])) - task_prediction_ids_set = set( - map(lambda prediction: prediction["id"], - exported_model_run["predictions"])) - for label_id in task_label_ids_set: - assert label_id in label_ids - for prediction_id in task_prediction_ids_set: - assert prediction_id in label_ids - ----- -tests/integration/export/streamable/test_export_video_streamable.py -import json - -import pytest - -import labelbox as lb -import labelbox.types as lb_types -from labelbox.data.annotation_types.data.video import VideoData -from labelbox.schema.annotation_import import AnnotationImportState -from labelbox.schema.export_task import ExportTask, StreamType - - -class TestExportVideo: - - @pytest.fixture - def user_id(self, client): - return client.get_user().uid - - @pytest.fixture - def org_id(self, client): - return client.get_organization().uid - - def test_export( - self, - client, - configured_project_without_data_rows, - video_data, - video_data_row, - bbox_video_annotation_objects, - rand_gen, - ): - project = configured_project_without_data_rows - project_id = project.uid - labels = [] - - _, data_row_uids = video_data - project.create_batch( - rand_gen(str), - data_row_uids, # sample of data row objects - 5, # priority between 1(Highest) - 5(lowest) - ) - - for data_row_uid in data_row_uids: - labels = [ - lb_types.Label(data=VideoData(uid=data_row_uid), - annotations=bbox_video_annotation_objects) - ] - - label_import = lb.LabelImport.create_from_objects( - client, project_id, f"test-import-{project_id}", labels) - label_import.wait_until_done() - - assert label_import.state == AnnotationImportState.FINISHED - assert len(label_import.errors) == 0 - - export_task = project.export( - params={ - "performance_details": False, - "label_details": True, - "interpolated_frames": True, - }, - 
task_name="TestExportVideo:test_export", - ) - export_task.wait_till_done() - assert export_task.status == "COMPLETE" - assert isinstance(export_task, ExportTask) - assert export_task.has_result() - assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - - export_data = json.loads(list(export_task.get_stream())[0].json_str) - data_row_export = export_data["data_row"] - assert data_row_export["global_key"] == video_data_row["global_key"] - assert data_row_export["row_data"] == video_data_row["row_data"] - assert export_data["media_attributes"]["mime_type"] == "video/mp4" - assert export_data["media_attributes"][ - "frame_rate"] == 10 # as per the video_data fixture - assert (export_data["media_attributes"]["frame_count"] == 100 - ) # as per the video_data fixture - expected_export_label = { - "label_kind": "Video", - "version": "1.0.0", - "id": "clgjnpysl000xi3zxtnp29fug", - "label_details": { - "created_at": "2023-04-16T17:04:23+00:00", - "updated_at": "2023-04-16T17:04:23+00:00", - "created_by": "vbrodsky@labelbox.com", - "content_last_updated_at": "2023-04-16T17:04:23+00:00", - "reviews": [], - }, - "annotations": { - "frames": { - "13": { - "objects": { - "clgjnpyse000ui3zx6fr1d880": { - "feature_id": "clgjnpyse000ui3zx6fr1d880", - "name": "bbox", - "annotation_kind": "VideoBoundingBox", - "classifications": [{ - "feature_id": "clgjnpyse000vi3zxtgtfh01y", - "name": "nested", - "radio_answer": { - "feature_id": - "clgjnpyse000wi3zxnxgv53ps", - "name": - "radio_option_1", - "classifications": [], - }, - }], - "bounding_box": { - "top": 98.0, - "left": 146.0, - "height": 243.0, - "width": 236.0, - }, - } - }, - "classifications": [], - }, - "18": { - "objects": { - "clgjnpyse000ui3zx6fr1d880": { - "feature_id": "clgjnpyse000ui3zx6fr1d880", - "name": "bbox", - "annotation_kind": "VideoBoundingBox", - "classifications": [{ - "feature_id": "clgjnpyse000vi3zxtgtfh01y", - "name": "nested", - "radio_answer": { - "feature_id": - "clgjnpyse000wi3zxnxgv53ps", - "name": - "radio_option_1", - "classifications": [], - }, - }], - "bounding_box": { - "top": 98.0, - "left": 146.0, - "height": 243.0, - "width": 236.0, - }, - } - }, - "classifications": [], - }, - "19": { - "objects": { - "clgjnpyse000ui3zx6fr1d880": { - "feature_id": "clgjnpyse000ui3zx6fr1d880", - "name": "bbox", - "annotation_kind": "VideoBoundingBox", - "classifications": [], - "bounding_box": { - "top": 98.0, - "left": 146.0, - "height": 243.0, - "width": 236.0, - }, - } - }, - "classifications": [], - }, - }, - "segments": { - "clgjnpyse000ui3zx6fr1d880": [[13, 13], [18, 19]] - }, - "key_frame_feature_map": { - "clgjnpyse000ui3zx6fr1d880": { - "13": True, - "18": False, - "19": True - } - }, - "classifications": [], - }, - } - - project_export_labels = export_data["projects"][project_id]["labels"] - assert len(project_export_labels) == len( - labels - ) # note we create 1 label per data row, 1 data row so 1 label - export_label = project_export_labels[0] - assert (export_label["label_kind"]) == "Video" - - assert (export_label["label_details"].keys() - ) == expected_export_label["label_details"].keys() - - expected_frames_ids = [ - vannotation.frame for vannotation in bbox_video_annotation_objects - ] - export_annotations = export_label["annotations"] - export_frames = export_annotations["frames"] - export_frames_ids = [int(frame_id) for frame_id in export_frames.keys()] - all_frames_exported = [] - for (value) in ( - expected_frames_ids - ): # note need to understand 
why we are exporting more frames than we created
-            if value not in export_frames_ids:
-                all_frames_exported.append(value)
-        assert len(all_frames_exported) == 0
-
-        # BEGINNING OF THE VIDEO INTERPOLATION ASSERTIONS
-        first_frame_id = bbox_video_annotation_objects[0].frame
-        last_frame_id = bbox_video_annotation_objects[-1].frame
-
-        # Generate list of frames with frames in between, e.g. 13, 14, 15, 16, 17, 18, 19
-        expected_frame_ids = list(range(first_frame_id, last_frame_id + 1))
-
-        assert export_frames_ids == expected_frame_ids
-
-        exported_objects_dict = export_frames[str(first_frame_id)]["objects"]
-
-        # Get the label ID
-        first_exported_label_id = list(exported_objects_dict.keys())[0]
-
-        # Since the bounding box moves to the right, the interpolated frame content should start
-        # a little bit farther to the right
-        assert (export_frames[str(first_frame_id + 1)]["objects"]
-                [first_exported_label_id]["bounding_box"]["left"]
-                > export_frames[str(first_frame_id)]["objects"]
-                [first_exported_label_id]["bounding_box"]["left"])
-        # But it shouldn't be further than the last frame
-        assert (export_frames[str(first_frame_id + 1)]["objects"]
-                [first_exported_label_id]["bounding_box"]["left"]
-                < export_frames[str(last_frame_id)]["objects"]
-                [first_exported_label_id]["bounding_box"]["left"])
-        # END OF THE VIDEO INTERPOLATION ASSERTIONS
-
-        frame_with_nested_classifications = export_frames["13"]
-        annotation = None
-        for _, a in frame_with_nested_classifications["objects"].items():
-            if a["name"] == "bbox":
-                annotation = a
-                break
-        assert annotation is not None
-        assert annotation["annotation_kind"] == "VideoBoundingBox"
-        assert annotation["classifications"]
-        assert annotation["bounding_box"] == {
-            "top": 98.0,
-            "left": 146.0,
-            "height": 243.0,
-            "width": 236.0,
-        }
-        classifications = annotation["classifications"]
-        classification = classifications[0]["radio_answer"]
-        assert classification["name"] == "radio_option_1"
-        subclassifications = classification["classifications"]
-        # NOTE the predictions service does not support nested classifications at the moment, see
-        # https://labelbox.atlassian.net/browse/AL-5588
-        assert len(subclassifications) == 0
-
-----
-tests/integration/export/streamable/test_export_dataset_streamable.py
-import json
-
-import pytest
-
-from labelbox import ExportTask, StreamType
-
-
-class TestExportDataset:
-
-    @pytest.mark.parametrize("data_rows", [3], indirect=True)
-    def test_export(self, dataset, data_rows):
-        expected_data_row_ids = [dr.uid for dr in data_rows]
-
-        export_task = dataset.export(task_name="TestExportDataset:test_export")
-        export_task.wait_till_done()
-
-        assert export_task.status == "COMPLETE"
-        assert isinstance(export_task, ExportTask)
-        assert export_task.has_result()
-        assert export_task.has_errors() is False
-        assert export_task.get_total_file_size(
-            stream_type=StreamType.RESULT) > 0
-        assert export_task.get_total_lines(
-            stream_type=StreamType.RESULT) == len(expected_data_row_ids)
-        data_row_ids = list(
-            map(lambda x: json.loads(x.json_str)["data_row"]["id"],
-                export_task.get_stream()))
-        assert sorted(data_row_ids) == sorted(expected_data_row_ids)
-
-    @pytest.mark.parametrize("data_rows", [3], indirect=True)
-    def test_with_data_row_filter(self, dataset, data_rows):
-        datarow_filter_size = 3
-        expected_data_row_ids = [dr.uid for dr in data_rows
-                                ][:datarow_filter_size]
-        filters = {"data_row_ids": expected_data_row_ids}
-
-        export_task = dataset.export(
-            filters=filters,
-            task_name="TestExportDataset:test_with_data_row_filter")
-        export_task.wait_till_done()
-
-        assert export_task.status == "COMPLETE"
-        assert isinstance(export_task, ExportTask)
-        assert export_task.has_result()
-        assert export_task.has_errors() is False
-        assert export_task.get_total_file_size(
-            stream_type=StreamType.RESULT) > 0
-        assert export_task.get_total_lines(
-            stream_type=StreamType.RESULT) == datarow_filter_size
-        data_row_ids = list(
-            map(lambda x: json.loads(x.json_str)["data_row"]["id"],
-                export_task.get_stream()))
-        assert sorted(data_row_ids) == sorted(expected_data_row_ids)
-
-    @pytest.mark.parametrize("data_rows", [3], indirect=True)
-    def test_with_global_key_filter(self, dataset, data_rows):
-        datarow_filter_size = 2
-        expected_global_keys = [dr.global_key for dr in data_rows
-                               ][:datarow_filter_size]
-        filters = {"global_keys": expected_global_keys}
-
-        export_task = dataset.export(
-            filters=filters,
-            task_name="TestExportDataset:test_with_global_key_filter")
-        export_task.wait_till_done()
-
-        assert export_task.status == "COMPLETE"
-        assert isinstance(export_task, ExportTask)
-        assert export_task.has_result()
-        assert export_task.has_errors() is False
-        assert export_task.get_total_file_size(
-            stream_type=StreamType.RESULT) > 0
-        assert export_task.get_total_lines(
-            stream_type=StreamType.RESULT) == datarow_filter_size
-        global_keys = list(
-            map(lambda x: json.loads(x.json_str)["data_row"]["global_key"],
-                export_task.get_stream()))
-        assert sorted(global_keys) == sorted(expected_global_keys)
-
-----
-tests/integration/export/streamable/test_export_project_streamable.py
-from datetime import datetime, timezone, timedelta
-import json
-
-import pytest
-import uuid
-from typing import Tuple
-from labelbox.schema.export_task import ExportTask, StreamType
-
-from labelbox.schema.media_type import MediaType
-from labelbox import Project, Dataset
-from labelbox.schema.data_row import DataRow
-from labelbox.schema.label import Label
-
-IMAGE_URL = (
-    "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg"
-)
-
-
-class TestExportProject:
-
-    @pytest.fixture
-    def project_export(self):
-
-        def _project_export(project, task_name, filters=None, params=None):
-            export_task = project.export(
-                task_name=task_name,
-                filters=filters,
-                params=params,
-            )
-            export_task.wait_till_done()
-
-            assert export_task.status == "COMPLETE"
-            assert isinstance(export_task, ExportTask)
-            return export_task
-
-        return _project_export
-
-    def test_export(
-        self,
-        client,
-        configured_project_with_label,
-        wait_for_data_row_processing,
-        project_export,
-    ):
-        project, dataset, data_row, label = configured_project_with_label
-        data_row = wait_for_data_row_processing(client, data_row)
-        label_id = label.uid
-        task_name = "TestExportProject:test_export"
-        params = {
-            "include_performance_details": True,
-            "include_labels": True,
-            "media_type_override": MediaType.Image,
-            "project_details": True,
-            "data_row_details": True,
-        }
-        export_task = project_export(project, task_name, params=params)
-        assert export_task.has_result()
-        assert export_task.has_errors() is False
-        assert export_task.get_total_file_size(
-            stream_type=StreamType.RESULT) > 0
-        assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0
-
-        for data in export_task.get_stream():
-            obj = json.loads(data.json_str)
-            task_media_attributes = obj["media_attributes"]
-            task_project = obj["projects"][project.uid]
-            task_project_label_ids_set = set(
-                map(lambda prediction:
prediction["id"], - task_project["labels"])) - task_project_details = task_project["project_details"] - task_data_row = obj["data_row"] - task_data_row_details = task_data_row["details"] - - assert label_id in task_project_label_ids_set - # data row - assert task_data_row["id"] == data_row.uid - assert task_data_row["external_id"] == data_row.external_id - assert task_data_row["row_data"] == data_row.row_data - - # data row details - assert task_data_row_details["dataset_id"] == dataset.uid - assert task_data_row_details["dataset_name"] == dataset.name - - actual_time = datetime.fromisoformat( - task_data_row_details["created_at"]) - expected_time = datetime.fromisoformat( - dataset.created_at.strftime("%Y-%m-%dT%H:%M:%S.%f")) - actual_time = actual_time.replace(tzinfo=timezone.utc) - expected_time = expected_time.replace(tzinfo=timezone.utc) - tolerance = timedelta(seconds=2) - assert abs(actual_time - expected_time) <= tolerance - - assert task_data_row_details["last_activity_at"] is not None - assert task_data_row_details["created_by"] is not None - - # media attributes - assert task_media_attributes[ - "mime_type"] == data_row.media_attributes["mimeType"] - - # project name and details - assert task_project["name"] == project.name - batch = next(project.batches()) - assert task_project_details["batch_id"] == batch.uid - assert task_project_details["batch_name"] == batch.name - assert task_project_details["priority"] is not None - assert task_project_details[ - "consensus_expected_label_count"] is not None - assert task_project_details["workflow_history"] is not None - - # label details - assert task_project["labels"][0]["id"] == label_id - - def test_with_date_filters( - self, - client, - configured_project_with_label, - wait_for_data_row_processing, - project_export, - ): - project, _, data_row, label = configured_project_with_label - data_row = wait_for_data_row_processing(client, data_row) - label_id = label.uid - task_name = "TestExportProject:test_with_date_filters" - filters = { - "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "task_queue_status": "InReview", - } - include_performance_details = True - params = { - "performance_details": include_performance_details, - "include_labels": True, - "project_details": True, - "media_type_override": MediaType.Image, - } - task_queues = project.task_queues() - review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") - project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid) - export_task = project_export(project, - task_name, - filters=filters, - params=params) - assert export_task.has_result() - assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 - - for data in export_task.get_stream(): - obj = json.loads(data.json_str) - task_project = obj["projects"][project.uid] - task_project_label_ids_set = set( - map(lambda prediction: prediction["id"], - task_project["labels"])) - assert label_id in task_project_label_ids_set - assert task_project["project_details"][ - "workflow_status"] == "IN_REVIEW" - - def test_with_iso_date_filters( - self, - client, - configured_project_with_label, - wait_for_data_row_processing, - project_export, - ): - project, _, data_row, label = configured_project_with_label - data_row = wait_for_data_row_processing(client, data_row) - 
label_id = label.uid - task_name = "TestExportProject:test_with_iso_date_filters" - filters = { - "last_activity_at": [ - "2000-01-01T00:00:00+0230", "2050-01-01T00:00:00+0230" - ], - "label_created_at": [ - "2000-01-01T00:00:00+0230", "2050-01-01T00:00:00+0230" - ], - } - export_task = project_export(project, task_name, filters=filters) - assert export_task.has_result() - assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 - assert (label_id == json.loads( - list(export_task.get_stream())[0].json_str)["projects"][project.uid] - ["labels"][0]["id"]) - - def test_with_iso_date_filters_no_start_date( - self, - client, - configured_project_with_label, - wait_for_data_row_processing, - project_export, - ): - project, _, data_row, label = configured_project_with_label - data_row = wait_for_data_row_processing(client, data_row) - label_id = label.uid - task_name = "TestExportProject:test_with_iso_date_filters_no_start_date" - filters = {"last_activity_at": [None, "2050-01-01T00:00:00+0230"]} - export_task = project_export(project, task_name, filters=filters) - assert export_task.has_result() - assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 - assert (label_id == json.loads( - list(export_task.get_stream())[0].json_str)["projects"][project.uid] - ["labels"][0]["id"]) - - def test_with_iso_date_filters_and_future_start_date( - self, - client, - configured_project_with_label, - wait_for_data_row_processing, - project_export, - ): - project, _, data_row, _label = configured_project_with_label - data_row = wait_for_data_row_processing(client, data_row) - task_name = "TestExportProject:test_with_iso_date_filters_and_future_start_date" - filters = {"label_created_at": ["2050-01-01T00:00:00+0230", None]} - export_task = project_export(project, task_name, filters=filters) - assert export_task.has_result() is False - assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) is None - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) is None - - @pytest.mark.parametrize("data_rows", [3], indirect=True) - def test_with_data_row_filter( - self, configured_batch_project_with_multiple_datarows, - project_export): - project, _, data_rows = configured_batch_project_with_multiple_datarows - datarow_filter_size = 2 - expected_data_row_ids = [dr.uid for dr in data_rows - ][:datarow_filter_size] - task_name = "TestExportProject:test_with_data_row_filter" - filters = { - "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "data_row_ids": expected_data_row_ids, - } - params = { - "data_row_details": True, - "media_type_override": MediaType.Image - } - export_task = project_export(project, - task_name, - filters=filters, - params=params) - assert export_task.has_result() - assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - # only 2 datarows should be exported - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == datarow_filter_size - data_row_ids = list( - map(lambda x: json.loads(x.json_str)["data_row"]["id"], - export_task.get_stream())) - assert data_row_ids.sort() == 
expected_data_row_ids.sort() - - @pytest.mark.parametrize("data_rows", [3], indirect=True) - def test_with_global_key_filter( - self, configured_batch_project_with_multiple_datarows, - project_export): - project, _, data_rows = configured_batch_project_with_multiple_datarows - datarow_filter_size = 2 - expected_global_keys = [dr.global_key for dr in data_rows - ][:datarow_filter_size] - task_name = "TestExportProject:test_with_global_key_filter" - filters = { - "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "global_keys": expected_global_keys, - } - params = { - "data_row_details": True, - "media_type_override": MediaType.Image - } - export_task = project_export(project, - task_name, - filters=filters, - params=params) - assert export_task.has_result() - assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - # only 2 datarows should be exported - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == datarow_filter_size - global_keys = list( - map(lambda x: json.loads(x.json_str)["data_row"]["global_key"], - export_task.get_stream())) - assert global_keys.sort() == expected_global_keys.sort() - - def test_batch( - self, - configured_batch_project_with_label: Tuple[Project, Dataset, DataRow, - Label], - dataset: Dataset, - image_url: str, - project_export, - ): - project, dataset, *_ = configured_batch_project_with_label - batch = list(project.batches())[0] - filters = { - "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "batch_ids": [batch.uid], - } - params = { - "include_performance_details": True, - "include_labels": True, - "media_type_override": MediaType.Image, - } - task_name = "TestExportProject:test_batch" - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image" - }, - ] * 2) - task.wait_till_done() - data_rows = [dr.uid for dr in list(dataset.export_data_rows())] - batch_one = f"batch one {uuid.uuid4()}" - - # This test creates two batches, only one batch should be exporter - # Creatin second batch that will not be used in the export due to the filter: batch_id - project.create_batch(batch_one, data_rows) - - export_task = project_export(project, - task_name, - filters=filters, - params=params) - assert export_task.has_result() - assert export_task.has_errors() is False - assert export_task.get_total_file_size( - stream_type=StreamType.RESULT) > 0 - assert export_task.get_total_lines( - stream_type=StreamType.RESULT) == batch.size - ----- -tests/integration/export/legacy/test_export_catalog.py -import pytest - - -@pytest.mark.parametrize('data_rows', [3], indirect=True) -def test_catalog_export_v2(client, export_v2_test_helpers, data_rows): - datarow_filter_size = 2 - data_row_ids = [dr.uid for dr in data_rows] - - params = {"performance_details": False, "label_details": False} - filters = {"data_row_ids": data_row_ids[:datarow_filter_size]} - - task_results = export_v2_test_helpers.run_catalog_export_v2_task( - client, filters=filters, params=params) - - # only 2 datarows should be exported - assert len(task_results) == datarow_filter_size - # only filtered datarows should be exported - assert set([dr['data_row']['id'] for dr in task_results - ]) == set(data_row_ids[:datarow_filter_size]) - ----- -tests/integration/export/legacy/test_export_data_rows.py -import time -from labelbox 
import DataRow - - -def test_export_data_rows(client, data_row, wait_for_data_row_processing): - # Ensure created data rows are indexed - data_row = wait_for_data_row_processing(client, data_row) - time.sleep(7) # temp fix for ES indexing delay - - task = DataRow.export_v2(client=client, data_rows=[data_row]) - task.wait_till_done() - assert task.status == "COMPLETE" - assert task.errors is None - assert len(task.result) == 1 - assert task.result[0]["data_row"]["id"] == data_row.uid - - task = DataRow.export_v2(client=client, data_rows=[data_row.uid]) - task.wait_till_done() - assert task.status == "COMPLETE" - assert task.errors is None - assert len(task.result) == 1 - assert task.result[0]["data_row"]["id"] == data_row.uid - - task = DataRow.export_v2(client=client, global_keys=[data_row.global_key]) - task.wait_till_done() - assert task.status == "COMPLETE" - assert task.errors is None - assert len(task.result) == 1 - assert task.result[0]["data_row"]["id"] == data_row.uid - ----- -tests/integration/export/legacy/test_export_slice.py -import pytest - - -@pytest.mark.skip( - 'Skipping until we have a way to create slices programatically') -def test_export_v2_slice(client): - # Since we don't have CRUD for slices, we'll just use the one that's already there - SLICE_ID = "clk04g1e4000ryb0rgsvy1dty" - slice = client.get_catalog_slice(SLICE_ID) - task = slice.export_v2(params={ - "performance_details": False, - "label_details": True - }) - task.wait_till_done() - assert task.status == "COMPLETE" - assert task.errors is None - assert len(task.result) != 0 - ----- -tests/integration/export/legacy/test_legacy_export.py -import uuid -import datetime -import time -import requests -import pytest - -from labelbox.data.annotation_types.annotation import ObjectAnnotation -from labelbox.schema.annotation_import import LabelImport -from labelbox import Dataset, Project - -IMAGE_URL = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg" - - -def test_export_annotations_nested_checklist( - client, configured_project_with_complex_ontology, - wait_for_data_row_processing): - project, data_row = configured_project_with_complex_ontology - data_row = wait_for_data_row_processing(client, data_row) - - ontology = project.ontology().normalized - - tool = ontology["tools"][0] - - nested_check = [ - subc for subc in tool["classifications"] - if subc["name"] == "test-checklist-class" - ][0] - - data = [{ - "uuid": - str(uuid.uuid4()), - "schemaId": - tool['featureSchemaId'], - "dataRow": { - "id": data_row.uid - }, - "bbox": { - "top": 20, - "left": 20, - "height": 50, - "width": 50 - }, - "classifications": [{ - "schemaId": - nested_check["featureSchemaId"], - "answers": [ - { - "schemaId": nested_check["options"][0]["featureSchemaId"] - }, - { - "schemaId": nested_check["options"][1]["featureSchemaId"] - }, - ] - }] - }] - task = LabelImport.create_from_objects(client, project.uid, - f'label-import-{uuid.uuid4()}', data) - task.wait_until_done() - labels = project.label_generator() - object_annotation = [ - annot for annot in next(labels).annotations - if isinstance(annot, ObjectAnnotation) - ][0] - - nested_class_answers = object_annotation.classifications[0].value.answer - assert len(nested_class_answers) == 2 - - -def test_export_filtered_dates(client, - configured_project_with_complex_ontology): - project, data_row = configured_project_with_complex_ontology - ontology = project.ontology().normalized - - tool = ontology["tools"][0] - - data = [{ - "uuid": 
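- # NDJSON import payload: each object carries its own uuid, the tool's
- # featureSchemaId, the target data row id, and the bbox geometry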
str(uuid.uuid4()), - "schemaId": tool['featureSchemaId'], - "dataRow": { - "id": data_row.uid - }, - "bbox": { - "top": 20, - "left": 20, - "height": 50, - "width": 50 - } - }] - - task = LabelImport.create_from_objects(client, project.uid, - f'label-import-{uuid.uuid4()}', data) - task.wait_until_done() - - regular_export = project.export_labels(download=True) - assert len(regular_export) == 1 - - filtered_export = project.export_labels(download=True, start="2020-01-01") - assert len(filtered_export) == 1 - - filtered_export_with_time = project.export_labels( - download=True, start="2020-01-01 00:00:01") - assert len(filtered_export_with_time) == 1 - - empty_export = project.export_labels(download=True, - start="2020-01-01", - end="2020-01-02") - assert len(empty_export) == 0 - - -def test_export_filtered_activity(client, - configured_project_with_complex_ontology): - project, data_row = configured_project_with_complex_ontology - ontology = project.ontology().normalized - - tool = ontology["tools"][0] - - data = [{ - "uuid": str(uuid.uuid4()), - "schemaId": tool['featureSchemaId'], - "dataRow": { - "id": data_row.uid - }, - "bbox": { - "top": 20, - "left": 20, - "height": 50, - "width": 50 - } - }] - - task = LabelImport.create_from_objects(client, project.uid, - f'label-import-{uuid.uuid4()}', data) - task.wait_until_done() - - regular_export = project.export_labels(download=True) - assert len(regular_export) == 1 - - filtered_export = project.export_labels( - download=True, - last_activity_start="2020-01-01", - last_activity_end=(datetime.datetime.now() + - datetime.timedelta(days=2)).strftime("%Y-%m-%d")) - assert len(filtered_export) == 1 - - filtered_export_with_time = project.export_labels( - download=True, last_activity_start="2020-01-01 00:00:01") - assert len(filtered_export_with_time) == 1 - - empty_export = project.export_labels( - download=True, - last_activity_start=(datetime.datetime.now() + - datetime.timedelta(days=2)).strftime("%Y-%m-%d"), - ) - - empty_export = project.export_labels( - download=True, - last_activity_end=(datetime.datetime.now() - - datetime.timedelta(days=1)).strftime("%Y-%m-%d")) - assert len(empty_export) == 0 - - -def test_export_data_rows(project: Project, dataset: Dataset): - n_data_rows = 2 - task = dataset.create_data_rows([ - { - "row_data": IMAGE_URL, - "external_id": "my-image" - }, - ] * n_data_rows) - task.wait_till_done() - - data_rows = [dr.uid for dr in list(dataset.export_data_rows())] - batch = project.create_batch("batch test", data_rows) - result = list(batch.export_data_rows()) - exported_data_rows = [dr.uid for dr in result] - - assert len(result) == n_data_rows - assert set(data_rows) == set(exported_data_rows) - - -def test_queued_data_row_export(configured_project): - result = configured_project.export_queued_data_rows() - assert len(result) == 1 - - -def test_label_export(configured_project_with_label): - project, _, _, label = configured_project_with_label - label_id = label.uid - # Wait for exporter to retrieve latest labels - time.sleep(10) - - # TODO: Move to export_v2 - exported_labels_url = project.export_labels() - assert exported_labels_url is not None - exported_labels = requests.get(exported_labels_url) - labels = [example['ID'] for example in exported_labels.json()] - assert labels[0] == label_id - #TODO: Add test for bulk export back. 
- # The new exporter doesn't work with the create_label mutation - - -def test_issues_export(project): - exported_issues_url = project.export_issues() - assert exported_issues_url - - exported_issues_url = project.export_issues("Open") - assert exported_issues_url - assert "?status=Open" in exported_issues_url - - exported_issues_url = project.export_issues("Resolved") - assert exported_issues_url - assert "?status=Resolved" in exported_issues_url - - invalidStatusValue = "Closed" - with pytest.raises(ValueError) as exc_info: - exported_issues_url = project.export_issues(invalidStatusValue) - assert "status must be in" in str(exc_info.value) - assert "Found %s" % (invalidStatusValue) in str(exc_info.value) - - -def test_dataset_export(dataset, image_url): - n_data_rows = 2 - ids = set() - for _ in range(n_data_rows): - ids.add(dataset.create_data_row(row_data=image_url)) - result = list(dataset.export_data_rows()) - assert len(result) == n_data_rows - assert set(result) == ids - - -def test_data_row_export_with_empty_media_attributes( - client, configured_project_with_label, wait_for_data_row_processing): - project, _, data_row, _ = configured_project_with_label - data_row = wait_for_data_row_processing(client, data_row) - labels = list(project.label_generator()) - assert len( - labels - ) == 1, "Label export job unexpectedly returned an empty result set`" - assert labels[0].data.media_attributes == {} - ----- -tests/integration/export/legacy/test_export_project.py -from datetime import datetime, timezone, timedelta - -import pytest -import uuid -from typing import Tuple - -from labelbox.schema.media_type import MediaType -from labelbox import Project, Dataset -from labelbox.schema.data_row import DataRow -from labelbox.schema.label import Label - -IMAGE_URL = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg" - - -def test_project_export_v2(client, export_v2_test_helpers, - configured_project_with_label, - wait_for_data_row_processing): - project, dataset, data_row, label = configured_project_with_label - data_row = wait_for_data_row_processing(client, data_row) - label_id = label.uid - - task_name = "test_label_export_v2" - params = { - "include_performance_details": True, - "include_labels": True, - "media_type_override": MediaType.Image, - "project_details": True, - "data_row_details": True - } - - task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, params=params) - - for task_result in task_results: - task_media_attributes = task_result['media_attributes'] - task_project = task_result['projects'][project.uid] - task_project_label_ids_set = set( - map(lambda prediction: prediction['id'], task_project['labels'])) - task_project_details = task_project['project_details'] - task_data_row = task_result['data_row'] - task_data_row_details = task_data_row['details'] - - assert label_id in task_project_label_ids_set - # data row - assert task_data_row['id'] == data_row.uid - assert task_data_row['external_id'] == data_row.external_id - assert task_data_row['row_data'] == data_row.row_data - - # data row details - assert task_data_row_details['dataset_id'] == dataset.uid - assert task_data_row_details['dataset_name'] == dataset.name - - actual_time = datetime.fromisoformat( - task_data_row_details['created_at']) - expected_time = datetime.fromisoformat( - dataset.created_at.strftime("%Y-%m-%dT%H:%M:%S.%f")) - actual_time = actual_time.replace(tzinfo=timezone.utc) - expected_time = 
expected_time.replace(tzinfo=timezone.utc) - tolerance = timedelta(seconds=2) - assert abs(actual_time - expected_time) <= tolerance - - assert task_data_row_details['last_activity_at'] is not None - assert task_data_row_details['created_by'] is not None - - # media attributes - assert task_media_attributes['mime_type'] == data_row.media_attributes[ - 'mimeType'] - - # project name and details - assert task_project['name'] == project.name - batch = next(project.batches()) - assert task_project_details['batch_id'] == batch.uid - assert task_project_details['batch_name'] == batch.name - assert task_project_details['priority'] is not None - assert task_project_details[ - 'consensus_expected_label_count'] is not None - assert task_project_details['workflow_history'] is not None - - # label details - assert task_project['labels'][0]['id'] == label_id - - -def test_project_export_v2_date_filters(client, export_v2_test_helpers, - configured_project_with_label, - wait_for_data_row_processing): - project, _, data_row, label = configured_project_with_label - data_row = wait_for_data_row_processing(client, data_row) - label_id = label.uid - - task_name = "test_label_export_v2_date_filters" - - filters = { - "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "task_queue_status": "InReview" - } - - # TODO: Right now we don't have a way to test this - include_performance_details = True - params = { - "performance_details": include_performance_details, - "include_labels": True, - "project_details": True, - "media_type_override": MediaType.Image - } - - task_queues = project.task_queues() - - review_queue = next( - tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE") - project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid) - - task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, filters=filters, params=params) - - for task_result in task_results: - task_project = task_result['projects'][project.uid] - task_project_label_ids_set = set( - map(lambda prediction: prediction['id'], task_project['labels'])) - assert label_id in task_project_label_ids_set - assert task_project['project_details']['workflow_status'] == 'IN_REVIEW' - - # TODO: Add back in when we have a way to test this - # if include_performance_details: - # assert 'include_performance_details' in task_result and task_result[ - # 'include_performance_details'] is not None - # else: - # assert 'include_performance_details' not in task_result or task_result[ - # 'include_performance_details'] is None - - filters = {"last_activity_at": [None, "2050-01-01 00:00:00"]} - export_v2_test_helpers.run_project_export_v2_task(project, filters=filters) - - filters = {"label_created_at": ["2000-01-01 00:00:00", None]} - export_v2_test_helpers.run_project_export_v2_task(project, filters=filters) - - -def test_project_export_v2_with_iso_date_filters(client, export_v2_test_helpers, - configured_project_with_label, - wait_for_data_row_processing): - project, _, data_row, label = configured_project_with_label - data_row = wait_for_data_row_processing(client, data_row) - label_id = label.uid - - task_name = "test_label_export_v2_with_iso_date_filters" - - filters = { - "last_activity_at": [ - "2000-01-01T00:00:00+0230", "2050-01-01T00:00:00+0230" - ], - "label_created_at": [ - "2000-01-01T00:00:00+0230", "2050-01-01T00:00:00+0230" - ] - } - task_results = export_v2_test_helpers.run_project_export_v2_task( 
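- # same window as the naive-date tests above, expressed as ISO 8601
- # with an explicit +02:30 UTC offset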
- project, task_name=task_name, filters=filters) - assert label_id == task_results[0]['projects'][ - project.uid]['labels'][0]['id'] - - filters = {"last_activity_at": [None, "2050-01-01T00:00:00+0230"]} - task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, filters=filters) - assert label_id == task_results[0]['projects'][ - project.uid]['labels'][0]['id'] - - filters = {"label_created_at": ["2050-01-01T00:00:00+0230", None]} - task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, filters=filters) - assert len(task_results) == 0 - - -@pytest.mark.parametrize("data_rows", [3], indirect=True) -def test_project_export_v2_datarows_filter( - export_v2_test_helpers, - configured_batch_project_with_multiple_datarows): - project, _, data_rows = configured_batch_project_with_multiple_datarows - - data_row_ids = [dr.uid for dr in data_rows] - datarow_filter_size = 2 - - filters = { - "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "data_row_ids": data_row_ids[:datarow_filter_size] - } - params = {"data_row_details": True, "media_type_override": MediaType.Image} - task_results = export_v2_test_helpers.run_project_export_v2_task( - project, filters=filters, params=params) - - # only 2 datarows should be exported - assert len(task_results) == datarow_filter_size - # only filtered datarows should be exported - assert set([dr['data_row']['id'] for dr in task_results - ]) == set(data_row_ids[:datarow_filter_size]) - - global_keys = [dr.global_key for dr in data_rows] - filters = { - "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "global_keys": global_keys[:datarow_filter_size] - } - params = {"data_row_details": True, "media_type_override": MediaType.Image} - task_results = export_v2_test_helpers.run_project_export_v2_task( - project, filters=filters, params=params) - - # only 2 datarows should be exported - assert len(task_results) == datarow_filter_size - # only filtered datarows should be exported - assert set([dr['data_row']['global_key'] for dr in task_results - ]) == set(global_keys[:datarow_filter_size]) - - -def test_batch_project_export_v2( - configured_batch_project_with_label: Tuple[Project, Dataset, DataRow, - Label], - export_v2_test_helpers, dataset: Dataset, image_url: str): - project, dataset, *_ = configured_batch_project_with_label - - batch = list(project.batches())[0] - filters = { - "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - "batch_ids": [batch.uid], - } - params = { - "include_performance_details": True, - "include_labels": True, - "media_type_override": MediaType.Image - } - task_name = "test_batch_export_v2" - task = dataset.create_data_rows([ - { - "row_data": image_url, - "external_id": "my-image" - }, - ] * 2) - task.wait_till_done() - data_rows = [dr.uid for dr in list(dataset.export_data_rows())] - batch_one = f'batch one {uuid.uuid4()}' - - # This test creates two batches, only one batch should be exporter - # Creatin second batch that will not be used in the export due to the filter: batch_id - project.create_batch(batch_one, data_rows) - - task_results = export_v2_test_helpers.run_project_export_v2_task( - project, task_name=task_name, filters=filters, params=params) - assert (batch.size == len(task_results)) 
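- 
- # For reference, a minimal sketch (helper name assumed, not part of this
- # module) of the polling pattern that export_v2_test_helpers wraps in the
- # tests above: start the export, block until the task settles, then hand
- # back the parsed result rows.
- def _run_export_v2(exportable, task_name=None, filters=None, params=None):
-     task = exportable.export_v2(task_name=task_name,
-                                 filters=filters,
-                                 params=params)
-     task.wait_till_done()  # poll until the backend marks the task done
-     assert task.status == "COMPLETE"
-     assert task.errors is None
-     return task.result  # list of dicts, one per exported data row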
- ----- -tests/integration/export/legacy/test_export_model_run.py -import time - - -def _model_run_export_v2_results(model_run, task_name, params, num_retries=5): - """Export model run results and retry if no results are returned.""" - while (num_retries > 0): - task = model_run.export_v2(task_name, params=params) - assert task.name == task_name - task.wait_till_done() - assert task.status == "COMPLETE" - assert task.errors is None - task_results = task.result - if len(task_results) == 0: - num_retries -= 1 - time.sleep(5) - else: - return task_results - return [] - - -def test_model_run_export_v2(model_run_with_data_rows): - model_run, labels = model_run_with_data_rows - label_ids = [label.uid for label in labels] - expected_data_rows = list(model_run.model_run_data_rows()) - - task_name = "test_task" - params = {"media_attributes": True, "predictions": True} - task_results = _model_run_export_v2_results(model_run, task_name, params) - assert len(task_results) == len(expected_data_rows) - - for task_result in task_results: - # Check export param handling - assert 'media_attributes' in task_result and task_result[ - 'media_attributes'] is not None - exported_model_run = task_result['experiments'][ - model_run.model_id]['runs'][model_run.uid] - task_label_ids_set = set( - map(lambda label: label['id'], exported_model_run['labels'])) - task_prediction_ids_set = set( - map(lambda prediction: prediction['id'], - exported_model_run['predictions'])) - for label_id in task_label_ids_set: - assert label_id in label_ids - for prediction_id in task_prediction_ids_set: - assert prediction_id in label_ids - ----- -tests/integration/export/legacy/test_export_video.py -import time - -import pytest -import labelbox as lb -from labelbox.data.annotation_types.data.video import VideoData -import labelbox.types as lb_types -from labelbox.schema.annotation_import import AnnotationImportState - - -@pytest.fixture -def user_id(client): - return client.get_user().uid - - -@pytest.fixture -def org_id(client): - return client.get_organization().uid - - -def test_export_v2_video( - client, - configured_project_without_data_rows, - video_data, - video_data_row, - bbox_video_annotation_objects, - rand_gen, -): - - project = configured_project_without_data_rows - project_id = project.uid - labels = [] - - _, data_row_uids = video_data - project.create_batch( - rand_gen(str), - data_row_uids, # sample of data row objects - 5 # priority between 1(Highest) - 5(lowest) - ) - - for data_row_uid in data_row_uids: - labels = [ - lb_types.Label(data=VideoData(uid=data_row_uid), - annotations=bbox_video_annotation_objects) - ] - - label_import = lb.LabelImport.create_from_objects( - client, project_id, f'test-import-{project_id}', labels) - label_import.wait_until_done() - - assert label_import.state == AnnotationImportState.FINISHED - assert len(label_import.errors) == 0 - - num_retries = 5 - task = None - - while (num_retries > 0): - task = project.export_v2( - params={ - "performance_details": False, - "label_details": True, - "interpolated_frames": True - }) - task.wait_till_done() - assert task.status == "COMPLETE" - assert task.errors is None - if len(task.result) == 0: - num_retries -= 1 - time.sleep(5) - else: - break - - export_data = task.result - data_row_export = export_data[0]['data_row'] - assert data_row_export['global_key'] == video_data_row['global_key'] - assert data_row_export['row_data'] == video_data_row['row_data'] - assert export_data[0]['media_attributes']['mime_type'] == 'video/mp4' - assert 
export_data[0]['media_attributes'][ - 'frame_rate'] == 10 # as per the video_data fixture - assert export_data[0]['media_attributes'][ - 'frame_count'] == 100 # as per the video_data fixture - expected_export_label = { - 'label_kind': 'Video', - 'version': '1.0.0', - 'id': 'clgjnpysl000xi3zxtnp29fug', - 'label_details': { - 'created_at': '2023-04-16T17:04:23+00:00', - 'updated_at': '2023-04-16T17:04:23+00:00', - 'created_by': 'vbrodsky@labelbox.com', - 'content_last_updated_at': '2023-04-16T17:04:23+00:00', - 'reviews': [] - }, - 'annotations': { - 'frames': { - '13': { - 'objects': { - 'clgjnpyse000ui3zx6fr1d880': { - 'feature_id': 'clgjnpyse000ui3zx6fr1d880', - 'name': 'bbox', - 'annotation_kind': 'VideoBoundingBox', - 'classifications': [{ - 'feature_id': 'clgjnpyse000vi3zxtgtfh01y', - 'name': 'nested', - 'radio_answer': { - 'feature_id': 'clgjnpyse000wi3zxnxgv53ps', - 'name': 'radio_option_1', - 'classifications': [] - } - }], - 'bounding_box': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - } - } - }, - 'classifications': [] - }, - '18': { - 'objects': { - 'clgjnpyse000ui3zx6fr1d880': { - 'feature_id': 'clgjnpyse000ui3zx6fr1d880', - 'name': 'bbox', - 'annotation_kind': 'VideoBoundingBox', - 'classifications': [{ - 'feature_id': 'clgjnpyse000vi3zxtgtfh01y', - 'name': 'nested', - 'radio_answer': { - 'feature_id': 'clgjnpyse000wi3zxnxgv53ps', - 'name': 'radio_option_1', - 'classifications': [] - } - }], - 'bounding_box': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - } - } - }, - 'classifications': [] - }, - '19': { - 'objects': { - 'clgjnpyse000ui3zx6fr1d880': { - 'feature_id': 'clgjnpyse000ui3zx6fr1d880', - 'name': 'bbox', - 'annotation_kind': 'VideoBoundingBox', - 'classifications': [], - 'bounding_box': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - } - } - }, - 'classifications': [] - } - }, - 'segments': { - 'clgjnpyse000ui3zx6fr1d880': [[13, 13], [18, 19]] - }, - 'key_frame_feature_map': { - 'clgjnpyse000ui3zx6fr1d880': { - '13': True, - '18': False, - '19': True - } - }, - 'classifications': [] - } - } - - project_export_labels = export_data[0]['projects'][project_id]['labels'] - assert (len(project_export_labels) == len(labels) - ) #note we create 1 label per data row, 1 data row so 1 label - export_label = project_export_labels[0] - assert (export_label['label_kind']) == 'Video' - - assert (export_label['label_details'].keys() - ) == expected_export_label['label_details'].keys() - - expected_frames_ids = [ - vannotation.frame for vannotation in bbox_video_annotation_objects - ] - export_annotations = export_label['annotations'] - export_frames = export_annotations['frames'] - export_frames_ids = [int(frame_id) for frame_id in export_frames.keys()] - all_frames_exported = [] - for value in expected_frames_ids: # note need to understand why we are exporting more frames than we created - if value not in export_frames_ids: - all_frames_exported.append(value) - assert (len(all_frames_exported) == 0) - - # BEGINNING OF THE VIDEO INTERPOLATION ASSERTIONS - first_frame_id = bbox_video_annotation_objects[0].frame - last_frame_id = bbox_video_annotation_objects[-1].frame - - # Generate list of frames with frames in between, e.g. 
13, 14, 15, 16, 17, 18, 19 - expected_frame_ids = list(range(first_frame_id, last_frame_id + 1)) - - assert export_frames_ids == expected_frame_ids - - exported_objects_dict = export_frames[str(first_frame_id)]['objects'] - - # Get the label ID - first_exported_label_id = list(exported_objects_dict.keys())[0] - - # Since the bounding box moves to the right, the interpolated frame content should start a little bit more far to the right - assert export_frames[str(first_frame_id + 1)]['objects'][ - first_exported_label_id]['bounding_box']['left'] > export_frames[ - str(first_frame_id - )]['objects'][first_exported_label_id]['bounding_box']['left'] - # But it shouldn't be further than the last frame - assert export_frames[str(first_frame_id + 1)]['objects'][ - first_exported_label_id]['bounding_box']['left'] < export_frames[ - str(last_frame_id - )]['objects'][first_exported_label_id]['bounding_box']['left'] - # END OF THE VIDEO INTERPOLATION ASSERTIONS - - frame_with_nested_classifications = export_frames['13'] - annotation = None - for _, a in frame_with_nested_classifications['objects'].items(): - if a['name'] == 'bbox': - annotation = a - break - assert (annotation is not None) - assert (annotation['annotation_kind'] == 'VideoBoundingBox') - assert (annotation['classifications']) - assert (annotation['bounding_box'] == { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - }) - classifications = annotation['classifications'] - classification = classifications[0]['radio_answer'] - assert (classification['name'] == 'radio_option_1') - subclassifications = classification['classifications'] - # NOTE predictions services does not support nested classifications at the moment, see - # https://labelbox.atlassian.net/browse/AL-5588 - assert (len(subclassifications) == 0) - ----- -tests/integration/export/legacy/test_export_dataset.py -import pytest - - -@pytest.mark.parametrize('data_rows', [3], indirect=True) -def test_dataset_export_v2(export_v2_test_helpers, dataset, data_rows): - data_row_ids = [dr.uid for dr in data_rows] - params = {"performance_details": False, "label_details": False} - task_results = export_v2_test_helpers.run_dataset_export_v2_task( - dataset, params=params) - assert len(task_results) == len(data_row_ids) - assert set([dr['data_row']['id'] for dr in task_results - ]) == set(data_row_ids) - - # testing with a datarow ids filter - datarow_filter_size = 2 - data_row_ids = [dr.uid for dr in data_rows] - - params = {"performance_details": False, "label_details": False} - filters = {"data_row_ids": data_row_ids[:datarow_filter_size]} - - task_results = export_v2_test_helpers.run_dataset_export_v2_task( - dataset, filters=filters, params=params) - - # only 2 datarows should be exported - assert len(task_results) == datarow_filter_size - # only filtered datarows should be exported - assert set([dr['data_row']['id'] for dr in task_results - ]) == set(data_row_ids[:datarow_filter_size]) - - # testing with a global key and a datarow id filter - datarow_filter_size = 2 - global_keys = [dr.global_key for dr in data_rows] - - params = {"performance_details": False, "label_details": False} - filters = {"global_keys": global_keys[:datarow_filter_size]} - - task_results = export_v2_test_helpers.run_dataset_export_v2_task( - dataset, filters=filters, params=params) - - # only 2 datarows should be exported - assert len(task_results) == datarow_filter_size - # only filtered datarows should be exported - assert set([dr['data_row']['global_key'] for dr in task_results - ]) == 
set(global_keys[:datarow_filter_size]) - ----- -tests/data/__init__.py - ----- -tests/data/test_prefetch_generator.py -import pytest -from labelbox.data.generator import PrefetchGenerator -from random import random - - -class ChildClassGenerator(PrefetchGenerator): - - def __init__(self, examples, num_executors=1): - super().__init__(data=examples, num_executors=num_executors) - - def _process(self, value): - num = random() - if num < .2: - raise ValueError("Randomized value error") - return value - - -amount = (i for i in range(50)) - - -def test_single_thread_generator(): - generator = ChildClassGenerator(amount, num_executors=1) - - with pytest.raises(ValueError): - for _ in range(51): - next(generator) - - -def test_multi_thread_generator(): - generator = ChildClassGenerator(amount, num_executors=4) - - with pytest.raises(ValueError): - for _ in range(51): - next(generator) - ----- -tests/data/metrics/confusion_matrix/conftest.py -from types import SimpleNamespace - -import pytest - -from labelbox.data.annotation_types import ClassificationAnnotation, ObjectAnnotation -from labelbox.data.annotation_types import Polygon, Point, Rectangle, Mask, MaskData, Line, Radio, Text, Checklist, ClassificationAnswer -import numpy as np - -from labelbox.data.annotation_types.ner import TextEntity - - -class NameSpace(SimpleNamespace): - - def __init__(self, - predictions, - ground_truths, - expected, - expected_without_subclasses=None): - super(NameSpace, self).__init__( - predictions=predictions, - ground_truths=ground_truths, - expected=expected, - expected_without_subclasses=expected_without_subclasses or expected) - - -def get_radio(name, answer_name): - return ClassificationAnnotation( - name=name, value=Radio(answer=ClassificationAnswer(name=answer_name))) - - -def get_text(name, text_content): - return ClassificationAnnotation(name=name, value=Text(answer=text_content)) - - -def get_checklist(name, answer_names): - return ClassificationAnnotation(name=name, - value=Radio(answer=[ - ClassificationAnswer(name=answer_name) - for answer_name in answer_names - ])) - - -def get_polygon(name, points, subclasses=None): - return ObjectAnnotation( - name=name, - value=Polygon(points=[Point(x=x, y=y) for x, y in points]), - classifications=[] if subclasses is None else subclasses) - - -def get_rectangle(name, start, end, subclasses=None): - return ObjectAnnotation( - name=name, - value=Rectangle(start=Point(x=start[0], y=start[1]), - end=Point(x=end[0], y=end[1])), - classifications=[] if subclasses is None else subclasses) - - -def get_mask(name, pixels, color=(1, 1, 1), subclasses=None): - mask = np.zeros((32, 32, 3)).astype(np.uint8) - for pixel in pixels: - mask[pixel[0], pixel[1]] = color - return ObjectAnnotation( - name=name, - value=Mask(mask=MaskData(arr=mask), color=color), - classifications=[] if subclasses is None else subclasses) - - -def get_line(name, points, subclasses=None): - return ObjectAnnotation( - name=name, - value=Line(points=[Point(x=x, y=y) for x, y in points]), - classifications=[] if subclasses is None else subclasses) - - -def get_point(name, x, y, subclasses=None): - return ObjectAnnotation( - name=name, - value=Point(x=x, y=y), - classifications=[] if subclasses is None else subclasses) - - -def get_radio(name, answer_name): - return ClassificationAnnotation( - name=name, value=Radio(answer=ClassificationAnswer(name=answer_name))) - - -def get_checklist(name, answer_names): - return ClassificationAnnotation(name=name, - value=Checklist(answer=[ - 
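- # NOTE: this redefinition shadows the get_checklist defined earlier in this
- # module, which wrapped the answers in Radio by mistake; Checklist is the
- # intended value type for multi-answer classifications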
ClassificationAnswer(name=answer_name) - for answer_name in answer_names - ])) - - -def get_ner(name, start, end, subclasses=None): - return ObjectAnnotation( - name=name, - value=TextEntity(start=start, end=end), - classifications=[] if subclasses is None else subclasses) - - -def get_object_pairs(tool_fn, **kwargs): - return [ - NameSpace(predictions=[tool_fn("cat", **kwargs)], - ground_truths=[tool_fn("cat", **kwargs)], - expected={'cat': [1, 0, 0, 0]}), - NameSpace( - predictions=[ - tool_fn("cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="yes")]) - ], - ground_truths=[ - tool_fn("cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="yes")]) - ], - expected={'cat': [1, 0, 0, 0]}, - expected_without_subclasses={'cat': [1, 0, 0, 0]}), - NameSpace(predictions=[ - tool_fn("cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="yes")]) - ], - ground_truths=[ - tool_fn( - "cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - expected={'cat': [0, 1, 0, 1]}, - expected_without_subclasses={'cat': [1, 0, 0, 0]}), - NameSpace(predictions=[ - tool_fn("cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="yes")]), - tool_fn("cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - ground_truths=[ - tool_fn( - "cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - expected={'cat': [1, 1, 0, 0]}, - expected_without_subclasses={'cat': [1, 1, 0, 0]}), - NameSpace(predictions=[ - tool_fn("cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="yes")]), - tool_fn("dog", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - ground_truths=[ - tool_fn( - "cat", - **kwargs, - subclasses=[get_radio("is_animal", answer_name="no")]) - ], - expected={ - 'cat': [0, 1, 0, 1], - 'dog': [0, 1, 0, 0] - }, - expected_without_subclasses={ - 'cat': [1, 0, 0, 0], - 'dog': [0, 1, 0, 0] - }), - NameSpace( - predictions=[tool_fn("cat", **kwargs), - tool_fn("cat", **kwargs)], - ground_truths=[tool_fn("cat", **kwargs), - tool_fn("cat", **kwargs)], - expected={'cat': [2, 0, 0, 0]}), - NameSpace( - predictions=[tool_fn("cat", **kwargs), - tool_fn("cat", **kwargs)], - ground_truths=[tool_fn("cat", **kwargs)], - expected={'cat': [1, 1, 0, 0]}), - NameSpace( - predictions=[tool_fn("cat", **kwargs)], - ground_truths=[tool_fn("cat", **kwargs), - tool_fn("cat", **kwargs)], - expected={'cat': [1, 0, 0, 1]}), - NameSpace(predictions=[], - ground_truths=[], - expected=[], - expected_without_subclasses=[]), - NameSpace(predictions=[], - ground_truths=[tool_fn("cat", **kwargs)], - expected={'cat': [0, 0, 0, 1]}), - NameSpace(predictions=[tool_fn("cat", **kwargs)], - ground_truths=[], - expected={'cat': [0, 1, 0, 0]}), - NameSpace(predictions=[tool_fn("cat", **kwargs)], - ground_truths=[tool_fn("dog", **kwargs)], - expected={ - 'cat': [0, 1, 0, 0], - 'dog': [0, 0, 0, 1] - }) - ] - - -@pytest.fixture -def radio_pairs(): - return [ - NameSpace(predictions=[get_radio("is_animal", answer_name="yes")], - ground_truths=[get_radio("is_animal", answer_name="yes")], - expected={'yes': [1, 0, 0, 0]}), - NameSpace(predictions=[get_radio("is_animal", answer_name="yes")], - ground_truths=[get_radio("is_animal", answer_name="no")], - expected={ - 'no': [0, 0, 0, 1], - 'yes': [0, 1, 0, 0] - }), - NameSpace(predictions=[get_radio("is_animal", answer_name="yes")], - ground_truths=[], - expected={'yes': [0, 1, 0, 0]}), - NameSpace(predictions=[], - 
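- # expected vectors are [TP, FP, TN, FN]: with no predictions, the lone
- # ground-truth radio answer counts as a single false negative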
ground_truths=[get_radio("is_animal", answer_name="yes")], - expected={'yes': [0, 0, 0, 1]}), - NameSpace(predictions=[ - get_radio("is_animal", answer_name="yes"), - get_radio("is_short", answer_name="no") - ], - ground_truths=[get_radio("is_animal", answer_name="yes")], - expected={ - 'no': [0, 1, 0, 0], - 'yes': [1, 0, 0, 0] - }), - #Not supported yet: - # NameSpace( - #predictions=[], - #ground_truths=[], - #expected = [0,0,1,0] - #) - ] - - -@pytest.fixture -def checklist_pairs(): - return [ - NameSpace(predictions=[ - get_checklist("animal_attributes", answer_names=["striped"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped"]) - ], - expected={'striped': [1, 0, 0, 0]}), - NameSpace(predictions=[ - get_checklist("animal_attributes", answer_names=["striped"]) - ], - ground_truths=[], - expected={'striped': [0, 1, 0, 0]}), - NameSpace(predictions=[], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped"]) - ], - expected={'striped': [0, 0, 0, 1]}), - NameSpace(predictions=[ - get_checklist("animal_attributes", - answer_names=["striped", "short"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped"]) - ], - expected={ - 'short': [0, 1, 0, 0], - 'striped': [1, 0, 0, 0] - }), - NameSpace(predictions=[ - get_checklist("animal_attributes", answer_names=["striped"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped", "short"]) - ], - expected={ - 'short': [0, 0, 0, 1], - 'striped': [1, 0, 0, 0] - }), - NameSpace(predictions=[ - get_checklist("animal_attributes", - answer_names=["striped", "short", "black"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped", "short"]) - ], - expected={ - 'black': [0, 1, 0, 0], - 'short': [1, 0, 0, 0], - 'striped': [1, 0, 0, 0] - }), - NameSpace(predictions=[ - get_checklist("animal_attributes", - answer_names=["striped", "short", "black"]), - get_checklist("animal_name", answer_names=["doggy", "pup"]) - ], - ground_truths=[ - get_checklist("animal_attributes", - answer_names=["striped", "short"]), - get_checklist("animal_name", answer_names=["pup"]) - ], - expected={ - 'black': [0, 1, 0, 0], - 'doggy': [0, 1, 0, 0], - 'pup': [1, 0, 0, 0], - 'short': [1, 0, 0, 0], - 'striped': [1, 0, 0, 0] - }) - - #Not supported yet: - # NameSpace( - #predictions=[], - #ground_truths=[], - #expected = [0,0,1,0] - #) - ] - - -@pytest.fixture -def polygon_pairs(): - return get_object_pairs(get_polygon, - points=[[0, 0], [10, 0], [10, 10], [0, 10]]) - - -@pytest.fixture -def rectangle_pairs(): - return get_object_pairs(get_rectangle, start=[0, 0], end=[10, 10]) - - -@pytest.fixture -def mask_pairs(): - return get_object_pairs(get_mask, pixels=[[0, 0]]) - - -@pytest.fixture -def line_pairs(): - return get_object_pairs(get_line, - points=[[0, 0], [10, 0], [10, 10], [0, 10]]) - - -@pytest.fixture -def point_pairs(): - return get_object_pairs(get_point, x=0, y=0) - - -@pytest.fixture -def ner_pairs(): - return get_object_pairs(get_ner, start=0, end=10) - - -@pytest.fixture() -def pair_iou_thresholds(): - return [ - NameSpace(predictions=[ - get_polygon("cat", points=[[0, 0], [10, 0], [10, 10], [0, 10]]), - ], - ground_truths=[ - get_polygon("cat", - points=[[0, 0], [5, 0], [5, 5], [0, 5]]), - ], - expected={ - 0.2: [1, 0, 0, 0], - 0.3: [0, 1, 0, 1] - }), - NameSpace( - predictions=[get_rectangle("cat", start=[0, 0], end=[10, 10])], - ground_truths=[get_rectangle("cat", start=[0, 0], end=[5, 5])], - expected={ - 0.2: 
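- # the 5x5 ground-truth box against the 10x10 prediction gives
- # IOU = 25 / 100 = 0.25, a match (TP) at threshold 0.2 but a
- # miss (FP + FN) at 0.3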
[1, 0, 0, 0], - 0.3: [0, 1, 0, 1] - }), - NameSpace(predictions=[get_point("cat", x=0, y=0)], - ground_truths=[get_point("cat", x=20, y=20)], - expected={ - 0.5: [1, 0, 0, 0], - 0.65: [0, 1, 0, 1] - }), - NameSpace(predictions=[ - get_line("cat", points=[[0, 0], [10, 0], [10, 10], [0, 10]]) - ], - ground_truths=[ - get_line("cat", - points=[[0, 0], [100, 0], [100, 100], [0, 100]]) - ], - expected={ - 0.3: [1, 0, 0, 0], - 0.65: [0, 1, 0, 1] - }), - NameSpace(predictions=[ - get_mask("cat", pixels=[[0, 0], [1, 1], [2, 2], [3, 3]]) - ], - ground_truths=[get_mask("cat", pixels=[[0, 0], [1, 1]])], - expected={ - 0.4: [1, 0, 0, 0], - 0.6: [0, 1, 0, 1] - }), - ] - ----- -tests/data/metrics/confusion_matrix/test_confusion_matrix_data_row.py -from pytest_cases import fixture_ref -from pytest_cases import parametrize, fixture_ref - -from labelbox.data.metrics.confusion_matrix.confusion_matrix import confusion_matrix_metric - - -@parametrize("tool_examples", [ - fixture_ref('polygon_pairs'), - fixture_ref('rectangle_pairs'), - fixture_ref('mask_pairs'), - fixture_ref('line_pairs'), - fixture_ref('point_pairs'), - fixture_ref('ner_pairs') -]) -def test_overlapping_objects(tool_examples): - for example in tool_examples: - - for include_subclasses, expected_attr_name in [[ - True, 'expected' - ], [False, 'expected_without_subclasses']]: - score = confusion_matrix_metric( - example.ground_truths, - example.predictions, - include_subclasses=include_subclasses) - - if len(getattr(example, expected_attr_name)) == 0: - assert len(score) == 0 - else: - expected = [0, 0, 0, 0] - for expected_values in getattr(example, - expected_attr_name).values(): - for idx in range(4): - expected[idx] += expected_values[idx] - assert score[0].value == tuple( - expected), f"{example.predictions},{example.ground_truths}" - - -@parametrize("tool_examples", - [fixture_ref('checklist_pairs'), - fixture_ref('radio_pairs')]) -def test_overlapping_classifications(tool_examples): - for example in tool_examples: - score = confusion_matrix_metric(example.ground_truths, - example.predictions) - if len(example.expected) == 0: - assert len(score) == 0 - else: - expected = [0, 0, 0, 0] - for expected_values in example.expected.values(): - for idx in range(4): - expected[idx] += expected_values[idx] - assert score[0].value == tuple( - expected), f"{example.predictions},{example.ground_truths}" - - -def test_partial_overlap(pair_iou_thresholds): - for example in pair_iou_thresholds: - for iou in example.expected.keys(): - score = confusion_matrix_metric(example.predictions, - example.ground_truths, - iou=iou) - assert score[0].value == tuple( - example.expected[iou] - ), f"{example.predictions},{example.ground_truths}" - ----- -tests/data/metrics/confusion_matrix/test_confusion_matrix_feature.py -from pytest_cases import fixture_ref -from pytest_cases import parametrize, fixture_ref - -from labelbox.data.metrics.confusion_matrix.confusion_matrix import feature_confusion_matrix_metric - - -@parametrize("tool_examples", [ - fixture_ref('polygon_pairs'), - fixture_ref('rectangle_pairs'), - fixture_ref('mask_pairs'), - fixture_ref('line_pairs'), - fixture_ref('point_pairs'), - fixture_ref('ner_pairs') -]) -def test_overlapping_objects(tool_examples): - for example in tool_examples: - for include_subclasses, expected_attr_name in [[ - True, 'expected' - ], [False, 'expected_without_subclasses']]: - metrics = feature_confusion_matrix_metric( - example.ground_truths, - example.predictions, - include_subclasses=include_subclasses) - - metrics = 
{r.feature_name: list(r.value) for r in metrics} - if len(getattr(example, expected_attr_name)) == 0: - assert len(metrics) == 0 - else: - assert metrics == getattr( - example, expected_attr_name - ), f"{example.predictions},{example.ground_truths}" - - -@parametrize("tool_examples", - [fixture_ref('checklist_pairs'), - fixture_ref('radio_pairs')]) -def test_overlapping_classifications(tool_examples): - for example in tool_examples: - - metrics = feature_confusion_matrix_metric(example.ground_truths, - example.predictions) - - metrics = {r.feature_name: list(r.value) for r in metrics} - if len(example.expected) == 0: - assert len(metrics) == 0 - else: - assert metrics == example.expected, f"{example.predictions},{example.ground_truths}" - ----- -tests/data/metrics/iou/data_row/conftest.py -from io import BytesIO -from types import SimpleNamespace -import pytest -import numpy as np -from PIL import Image -import base64 - - -class NameSpace(SimpleNamespace): - - def __init__(self, - predictions, - labels, - expected, - expected_without_subclasses=None, - data_row_expected=None, - media_attributes=None, - metadata=None, - classifications=None): - super(NameSpace, self).__init__( - predictions=predictions, - labels={ - 'DataRow ID': 'ckppihxc10005aeyjen11h7jh', - 'Labeled Data': "https://.jpg", - 'Media Attributes': media_attributes or {}, - 'DataRow Metadata': metadata or [], - 'Label': { - 'objects': labels, - 'classifications': classifications or [] - } - }, - expected=expected, - expected_without_subclasses=expected_without_subclasses or expected, - data_row_expected=data_row_expected) - - -@pytest.fixture -def polygon_pair(): - return NameSpace(labels=[{ - 'featureId': - 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 1 - }, { - 'x': 0, - 'y': 1 - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 0.5 - }, { - 'x': 0, - 'y': 0.5 - }] - }], - expected=0.5) - - -@pytest.fixture -def box_pair(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - } - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - } - }], - expected=1.0) - - -@pytest.fixture -def unmatched_prediction(): - return NameSpace(labels=[{ - 'featureId': - 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 1 - }, { - 'x': 0, - 'y': 1 - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 0.5 - }, { - 'x': 0, - 'y': 0.5 - }] - }, { - 'uuid': - 'd0ba2520-02e9-47d4-8736-088bbdbabbc3', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'polygon': [{ - 'x': 10, - 'y': 10 - }, { - 'x': 11, - 'y': 10 - }, { - 'x': 11, - 'y': 
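- # this second polygon lies far from the only label, so it never
- # matches; the data-row score averages the matched pair (0.5) with
- # this false positive (0.0), giving the expected 0.25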
1.5 - }, { - 'x': 10, - 'y': 1.5 - }] - }], - expected=0.25) - - -@pytest.fixture -def unmatched_label(): - return NameSpace(labels=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 1 - }, { - 'x': 0, - 'y': 1 - }] - }, { - 'featureId': - 'ckppiw3bs0007aeyjs3pvrqzi', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'polygon': [{ - 'x': 10, - 'y': 10 - }, { - 'x': 11, - 'y': 10 - }, { - 'x': 11, - 'y': 11 - }, { - 'x': 10, - 'y': 11 - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'polygon': [{ - 'x': 0, - 'y': 0 - }, { - 'x': 1, - 'y': 0 - }, { - 'x': 1, - 'y': 0.5 - }, { - 'x': 0, - 'y': 0.5 - }] - }], - expected=0.25) - - -def create_mask_url(indices, h, w, value): - mask = np.zeros((h, w, 3), dtype=np.uint8) - for idx in indices: - mask[idx] = value - return base64.b64encode(mask.tobytes()).decode('utf-8') - - -@pytest.fixture -def mask_pair(): - return NameSpace(labels=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'instanceURI': - create_mask_url([(0, 0), (0, 1)], 32, 32, (255, 255, 255)) - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'mask': { - 'instanceURI': - create_mask_url([(0, 0)], 32, 32, (1, 1, 1)), - 'colorRGB': (1, 1, 1) - } - }], - expected=0.5) - - -@pytest.fixture -def matching_radio(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckrm02no8000008l3arwp6h4f', - 'answer': { - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - } - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckrm02no8000008l3arwp6h4f', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answer': { - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - } - }], - expected=1.) 
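-
-
-# Illustrative helper, not used by the fixtures or tests in this file: a
-# minimal sketch of the intersection-over-union arithmetic the geometric
-# pairs above are built around, for axis-aligned boxes given as
-# {top, left, height, width} dicts like the `bbox` payloads in box_pair.
-# box_pair uses identical label and prediction boxes, hence expected=1.0;
-# polygon_pair's prediction covers half of the unit-square label, hence
-# expected=0.5. The name `_illustrative_box_iou` is ours, not the SDK's.
-def _illustrative_box_iou(a, b):
-    a_x2, a_y2 = a['left'] + a['width'], a['top'] + a['height']
-    b_x2, b_y2 = b['left'] + b['width'], b['top'] + b['height']
-    inter_w = max(0, min(a_x2, b_x2) - max(a['left'], b['left']))
-    inter_h = max(0, min(a_y2, b_y2) - max(a['top'], b['top']))
-    intersection = inter_w * inter_h
-    union = a['width'] * a['height'] + b['width'] * b['height'] - intersection
-    return intersection / union if union else 0.0
-
-
-# Sanity check against box_pair's values: identical boxes give IoU 1.0.
-_box = {'top': 1099, 'left': 2010, 'height': 690, 'width': 591}
-assert _illustrative_box_iou(_box, _box) == 1.0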
- - -@pytest.fixture -def empty_radio_label(): - return NameSpace(labels=[], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answer': { - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - } - }], - expected=0) - - -@pytest.fixture -def empty_radio_prediction(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': { - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - } - }], - predictions=[], - expected=0) - - -@pytest.fixture -def matching_checklist(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }] - }], - data_row_expected=1., - expected={1.0: 3}) - - -@pytest.fixture -def partially_matching_checklist_1(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }, { - 'schemaId': 'ckppie29m0003aeyjk1ixzcom' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }, { - 'schemaId': 'ckppiebx80004aeyjuwvos69e' - }] - }], - data_row_expected=0.6, - expected={ - 0.0: 2, - 1.0: 3 - }) - - -@pytest.fixture -def partially_matching_checklist_2(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }, { - 'schemaId': 'ckppiebx80004aeyjuwvos69e' - }] - }], - data_row_expected=0.5, - expected={ - 1.0: 2, - 0.0: 2 - }) - - -@pytest.fixture -def partially_matching_checklist_3(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': - '1234567890111213141516171', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 
'ckppid25v0000aeyjmxfwlc7t', - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }, { - 'schemaId': 'ckppidq4u0002aeyjmcc4toxw' - }, { - 'schemaId': 'ckppiebx80004aeyjuwvos69e' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - }, { - 'schemaId': 'ckppide010001aeyj0yhiaghc' - }] - }], - data_row_expected=0.5, - expected={ - 1.0: 2, - 0.0: 2 - }) - - -@pytest.fixture -def empty_checklist_label(): - return NameSpace(labels=[], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - }] - }], - data_row_expected=0.0, - expected={0.0: 1}) - - -@pytest.fixture -def empty_checklist_prediction(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answers': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t' - }] - }], - predictions=[], - data_row_expected=0.0, - expected={0.0: 1}) - - -@pytest.fixture -def matching_text(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'test' - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answer': 'test' - }], - expected=1.0) - - -@pytest.fixture -def not_matching_text(): - return NameSpace(labels=[], - classifications=[{ - 'featureId': '1234567890111213141516171', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'test' - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'answer': 'not_test' - }], - expected=0.) 
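-
-
-# Illustrative helper, not used by the fixtures or tests in this file: a
-# minimal sketch of the arithmetic behind the checklist expectations above,
-# assuming each checklist answer scores 1.0 when it appears in both the
-# label and the prediction and 0.0 when it appears in only one of them,
-# with data_row_expected being the mean of those per-answer scores. The
-# name `_illustrative_checklist_miou` is ours, not the SDK's.
-def _illustrative_checklist_miou(label_answers, prediction_answers):
-    shared = set(label_answers) & set(prediction_answers)
-    combined = set(label_answers) | set(prediction_answers)
-    return len(shared) / len(combined) if combined else 0.0
-
-
-# partially_matching_checklist_1: three shared answers and two unmatched
-# ones give per-answer scores of {1.0: 3, 0.0: 2} and a mean of 3 / 5 = 0.6.
-assert _illustrative_checklist_miou({'a', 'b', 'c', 'd'},
-                                    {'a', 'b', 'c', 'e'}) == 0.6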
- - -@pytest.fixture -def test_box_with_subclass(): - return NameSpace(labels=[{ - 'featureId': - 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - }, - 'classifications': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'test' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - }, - 'classifications': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'test' - }] - }], - expected=1.0) - - -@pytest.fixture -def test_box_with_wrong_subclass(): - return NameSpace(labels=[{ - 'featureId': - 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - }, - 'classifications': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'test' - }] - }], - predictions=[{ - 'uuid': - '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'schemaId': - 'ckppid25v0000aeyjmxfwlc7t', - "bbox": { - "top": 1099, - "left": 2010, - "height": 690, - "width": 591 - }, - 'classifications': [{ - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'answer': 'not_test' - }] - }], - expected=0.5, - expected_without_subclasses=1.0) - - -@pytest.fixture -def line_pair(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "line": [{ - "x": 0, - "y": 100 - }, { - "x": 0, - "y": 0 - }], - }], - predictions=[{ - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - "line": [{ - "x": 5, - "y": 95 - }, { - "x": 0, - "y": 0 - }], - }], - expected=0.9496975567603978) - - -@pytest.fixture -def point_pair(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "point": { - 'x': 0, - 'y': 0 - } - }], - predictions=[{ - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "point": { - 'x': 5, - 'y': 5 - } - }], - expected=0.879113232477017) - - -@pytest.fixture -def matching_ner(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'format': "text.location", - 'data': { - "location": { - "start": 0, - "end": 10 - } - } - }], - predictions=[{ - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "location": { - "start": 0, - "end": 10 - } - }], - expected=1) - - -@pytest.fixture -def no_matching_ner(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'format': "text.location", - 'data': { - "location": { - "start": 0, - "end": 5 - } - } - }], - predictions=[{ - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "location": { - "start": 5, - "end": 10 - } - }], - expected=0) - - -@pytest.fixture -def partial_matching_ner(): - return NameSpace(labels=[{ - 'featureId': 'ckppivl7p0006aeyj92cezr9d', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - 'format': 
"text.location", - 'data': { - "location": { - "start": 0, - "end": 7 - } - } - }], - predictions=[{ - 'dataRow': { - 'id': 'ckppihxc10005aeyjen11h7jh' - }, - 'uuid': '76e0dcea-fe46-43e5-95f5-a5e3f378520a', - 'schemaId': 'ckppid25v0000aeyjmxfwlc7t', - "location": { - "start": 3, - "end": 5 - } - }], - expected=0.2857142857142857) - ----- -tests/data/metrics/iou/data_row/test_data_row_iou.py -from labelbox.data.metrics.iou.iou import miou_metric -from pytest_cases import parametrize, fixture_ref -from unittest.mock import patch -import math -import numpy as np -import base64 - -from labelbox.data.metrics.iou import data_row_miou, feature_miou_metric -from labelbox.data.serialization import NDJsonConverter, LBV1Converter -from labelbox.data.annotation_types import Label, ImageData, Mask - - -def check_iou(pair, mask=None): - default = Label(data=ImageData( - uid="ckppihxc10005aeyjen11h7jh", media_attributes=None, metadata=None)) - prediction = next(NDJsonConverter.deserialize(pair.predictions), default) - label = next(LBV1Converter.deserialize([pair.labels])) - if mask: - for annotation in [*prediction.annotations, *label.annotations]: - if isinstance(annotation.value, Mask): - annotation.value.mask.arr = np.frombuffer( - base64.b64decode(annotation.value.mask.url.encode('utf-8')), - dtype=np.uint8).reshape((32, 32, 3)) - - for include_subclasses, expected_attr_name in [[ - True, 'expected' - ], [False, 'expected_without_subclasses']]: - assert math.isclose( - data_row_miou(label, - prediction, - include_subclasses=include_subclasses), - getattr(pair, expected_attr_name)) - assert math.isclose( - miou_metric(label.annotations, - prediction.annotations, - include_subclasses=include_subclasses)[0].value, - getattr(pair, expected_attr_name)) - feature_ious = feature_miou_metric( - label.annotations, - prediction.annotations, - include_subclasses=include_subclasses) - assert len( - feature_ious - ) == 1 # The tests run here should only have one class present. 
- assert math.isclose(feature_ious[0].value, - getattr(pair, expected_attr_name)) - - -def check_iou_checklist(pair, mask=None): - """Specialized test since checklists have more than one feature IoU.""" - default = Label(data=ImageData(uid="ckppihxc10005aeyjen11h7jh")) - prediction = next(NDJsonConverter.deserialize(pair.predictions), default) - label = next(LBV1Converter.deserialize([pair.labels])) - if mask: - for annotation in [*prediction.annotations, *label.annotations]: - if isinstance(annotation.value, Mask): - annotation.value.mask.arr = np.frombuffer( - base64.b64decode(annotation.value.mask.url.encode('utf-8')), - dtype=np.uint8).reshape((32, 32, 3)) - assert math.isclose(data_row_miou(label, prediction), - pair.data_row_expected) - assert math.isclose( - miou_metric(label.annotations, prediction.annotations)[0].value, - pair.data_row_expected) - feature_ious = feature_miou_metric(label.annotations, - prediction.annotations) - mapping = {} - for iou in feature_ious: - if not mapping.get(iou.value, None): - mapping[iou.value] = 0 - mapping[iou.value] += 1 - assert mapping == pair.expected - - -def strings_to_fixtures(strings): - return [fixture_ref(x) for x in strings] - - -def test_overlapping(polygon_pair, box_pair, mask_pair): - check_iou(polygon_pair) - check_iou(box_pair) - check_iou(mask_pair, True) - - -@parametrize("pair", - strings_to_fixtures([ - "unmatched_label", - "unmatched_prediction", - ])) -def test_unmatched(pair): - check_iou(pair) - - -@parametrize( - "pair", - strings_to_fixtures([ - "empty_radio_label", - "matching_radio", - "empty_radio_prediction", - ])) -def test_radio(pair): - check_iou(pair) - - -@parametrize( - "pair", - strings_to_fixtures([ - "matching_checklist", - "partially_matching_checklist_1", - "partially_matching_checklist_2", - "partially_matching_checklist_3", - "empty_checklist_label", - "empty_checklist_prediction", - ])) -def test_checklist(pair): - check_iou_checklist(pair) - - -@parametrize("pair", strings_to_fixtures(["matching_text", - "not_matching_text"])) -def test_text(pair): - check_iou(pair) - - -@parametrize( - "pair", - strings_to_fixtures( - ["test_box_with_wrong_subclass", "test_box_with_subclass"])) -def test_vector_with_subclass(pair): - check_iou(pair) - - -@parametrize("pair", strings_to_fixtures(["point_pair", "line_pair"])) -def test_others(pair): - check_iou(pair) - - -@parametrize( - "pair", - strings_to_fixtures( - ["matching_ner", "no_matching_ner", "partial_matching_ner"])) -def test_ner(pair): - check_iou(pair) - ----- -tests/data/metrics/iou/feature/conftest.py -from types import SimpleNamespace - -import pytest - -from labelbox.data.annotation_types import ClassificationAnnotation, ObjectAnnotation -from labelbox.data.annotation_types import Polygon, Point - - -class NameSpace(SimpleNamespace): - - def __init__(self, predictions, ground_truths, expected): - super(NameSpace, self).__init__(predictions=predictions, - ground_truths=ground_truths, - expected=expected) - - -@pytest.fixture -def different_classes(): - return [ - NameSpace(predictions=[ - ObjectAnnotation(name="cat", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - ground_truths=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - expected={ - 'cat': 0, - 'dog': 0 - }) - ] - - -@pytest.fixture -def one_overlap_class(): - return [ - NameSpace(predictions=[ - ObjectAnnotation(name="cat", - 
value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])), - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=5, y=0), - Point(x=5, y=5), - Point(x=0, y=5) - ])) - ], - ground_truths=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - expected={ - 'dog': 0.25, - 'cat': 0. - }), - NameSpace(predictions=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=5, y=0), - Point(x=5, y=5), - Point(x=0, y=5) - ])) - ], - ground_truths=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])), - ObjectAnnotation(name="cat", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - expected={ - 'dog': 0.25, - 'cat': 0. - }) - ] - - -@pytest.fixture -def empty_annotations(): - return [ - NameSpace(predictions=[], ground_truths=[], expected={}), - NameSpace(predictions=[], - ground_truths=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])), - ObjectAnnotation(name="cat", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - expected={ - 'dog': 0., - 'cat': 0. - }), - NameSpace(predictions=[ - ObjectAnnotation(name="dog", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])), - ObjectAnnotation(name="cat", - value=Polygon(points=[ - Point(x=0, y=0), - Point(x=10, y=0), - Point(x=10, y=10), - Point(x=0, y=10) - ])) - ], - ground_truths=[], - expected={ - 'dog': 0., - 'cat': 0. 
- }) - ] - ----- -tests/data/metrics/iou/feature/test_feature_iou.py -import math - -from labelbox.data.metrics.iou.iou import miou_metric, feature_miou_metric - - -def check_iou(pair): - one_metrics = miou_metric(pair.predictions, pair.ground_truths) - metrics = feature_miou_metric(pair.predictions, pair.ground_truths) - result = {metric.feature_name: metric.value for metric in metrics} - assert len(set(pair.expected.keys()).difference(set(result.keys()))) == 0 - - for key in result: - assert math.isclose(result[key], pair.expected[key]) - - for metric in metrics: - assert metric.metric_name == "custom_iou" - - if len(pair.expected): - assert len(one_metrics) - one_metric = one_metrics[0] - assert one_metric.value == sum(list(pair.expected.values())) / len( - pair.expected) - - -def test_different_classes(different_classes): - for pair in different_classes: - check_iou(pair) - - -def test_empty_annotations(empty_annotations): - for pair in empty_annotations: - check_iou(pair) - - -def test_one_overlap_classes(one_overlap_class): - for pair in one_overlap_class: - check_iou(pair) - ----- -tests/data/serialization/__init__.py - ----- -tests/data/serialization/ndjson/test_rectangle.py -import json -import labelbox.types as lb_types -from labelbox.data.serialization.ndjson.converter import NDJsonConverter - -DATAROW_ID = "ckrb1sf1i1g7i0ybcdc6oc8ct" - - -def test_rectangle(): - with open('tests/data/assets/ndjson/rectangle_import.json', 'r') as file: - data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert res == data - - -def test_rectangle_inverted_start_end_points(): - with open('tests/data/assets/ndjson/rectangle_import.json', 'r') as file: - data = json.load(file) - - bbox = lb_types.ObjectAnnotation( - name="bbox", - value=lb_types.Rectangle( - start=lb_types.Point(x=81, y=69), - end=lb_types.Point(x=38, y=28), - ), - extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}) - - label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID), - annotations=[bbox]) - - res = list(NDJsonConverter.serialize([label])) - assert res == data - - expected_bbox = lb_types.ObjectAnnotation( - name="bbox", - value=lb_types.Rectangle( - start=lb_types.Point(x=38, y=28), - end=lb_types.Point(x=81, y=69), - ), - extra={ - "uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72", - "page": None, - "unit": None - }) - - label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID), - annotations=[expected_bbox]) - - res = list(NDJsonConverter.deserialize(res)) - assert res == [label] - - -def test_rectangle_mixed_start_end_points(): - with open('tests/data/assets/ndjson/rectangle_import.json', 'r') as file: - data = json.load(file) - - bbox = lb_types.ObjectAnnotation( - name="bbox", - value=lb_types.Rectangle( - start=lb_types.Point(x=81, y=28), - end=lb_types.Point(x=38, y=69), - ), - extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}) - - label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID), - annotations=[bbox]) - - res = list(NDJsonConverter.serialize([label])) - assert res == data - - bbox = lb_types.ObjectAnnotation( - name="bbox", - value=lb_types.Rectangle( - start=lb_types.Point(x=38, y=28), - end=lb_types.Point(x=81, y=69), - ), - extra={ - "uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72", - "page": None, - "unit": None - }) - - label = lb_types.Label(data=lb_types.ImageData(uid=DATAROW_ID), - annotations=[bbox]) - - res = list(NDJsonConverter.deserialize(res)) - assert res == [label] - ----- 
-tests/data/serialization/ndjson/test_metric.py -import json - -from labelbox.data.serialization.labelbox_v1.converter import LBV1Converter -from labelbox.data.serialization.ndjson.converter import NDJsonConverter - - -def test_metric(): - with open('tests/data/assets/ndjson/metric_import.json', 'r') as file: - data = json.load(file) - - label_list = list(NDJsonConverter.deserialize(data)) - reserialized = list(NDJsonConverter.serialize(label_list)) - assert reserialized == data - - # Just make sure that this doesn't break - list(LBV1Converter.serialize(label_list)) - - -def test_custom_scalar_metric(): - with open('tests/data/assets/ndjson/custom_scalar_import.json', - 'r') as file: - data = json.load(file) - - label_list = list(NDJsonConverter.deserialize(data)) - reserialized = list(NDJsonConverter.serialize(label_list)) - assert json.dumps(reserialized, - sort_keys=True) == json.dumps(data, sort_keys=True) - - # Just make sure that this doesn't break - list(LBV1Converter.serialize(label_list)) - - -def test_custom_confusion_matrix_metric(): - with open('tests/data/assets/ndjson/custom_confusion_matrix_import.json', - 'r') as file: - data = json.load(file) - - label_list = list(NDJsonConverter.deserialize(data)) - reserialized = list(NDJsonConverter.serialize(label_list)) - assert json.dumps(reserialized, - sort_keys=True) == json.dumps(data, sort_keys=True) - - # Just make sure that this doesn't break - list(LBV1Converter.serialize(label_list)) - ----- -tests/data/serialization/ndjson/test_radio.py -import json -from labelbox.data.annotation_types.annotation import ClassificationAnnotation -from labelbox.data.annotation_types.classification.classification import ClassificationAnswer -from labelbox.data.annotation_types.classification.classification import Radio -from labelbox.data.annotation_types.data.text import TextData -from labelbox.data.annotation_types.label import Label - -from labelbox.data.serialization.ndjson.converter import NDJsonConverter - - -def test_serialization_with_radio_min(): - label = Label( - uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation( - name="radio_question_geo", - value=Radio( - answer=ClassificationAnswer(name="first_radio_answer",))) - ]) - - expected = { - 'name': 'radio_question_geo', - 'answer': { - 'name': 'first_radio_answer' - }, - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - } - } - serialized = NDJsonConverter.serialize([label]) - res = next(serialized) - - res.pop("uuid") - assert res == expected - - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - res.annotations[0].extra.pop("uuid") - assert res.annotations == label.annotations - - -def test_serialization_with_radio_classification(): - label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation( - name="radio_question_geo", - confidence=0.5, - value=Radio(answer=ClassificationAnswer( - confidence=0.6, - name="first_radio_answer", - classifications=[ - ClassificationAnnotation( - name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer"))) - ]))) - ]) - - expected = { - 'confidence': 0.5, - 'name': 'radio_question_geo', - 'answer': { - 'confidence': - 0.6, - 'name': - 'first_radio_answer', - 'classifications': [{ - 'name': 'sub_radio_question', - 'answer': { - 'name': 'first_sub_radio_answer', - } 
- }] - }, - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - } - } - - serialized = NDJsonConverter.serialize([label]) - res = next(serialized) - res.pop("uuid") - assert res == expected - - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - res.annotations[0].extra.pop("uuid") - assert res.annotations == label.annotations - ----- -tests/data/serialization/ndjson/test_free_text.py -from labelbox.data.annotation_types.annotation import ClassificationAnnotation -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnswer, Radio, Text -from labelbox.data.annotation_types.data.text import TextData -from labelbox.data.annotation_types.label import Label - -from labelbox.data.serialization.ndjson.converter import NDJsonConverter - - -def test_serialization(): - label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation(name="free_text_annotation", - value=Text(confidence=0.5, - answer="text_answer")) - ]) - - serialized = NDJsonConverter.serialize([label]) - res = next(serialized) - - assert res['confidence'] == 0.5 - assert res['name'] == "free_text_annotation" - assert res['answer'] == "text_answer" - assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d" - - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - - annotation = res.annotations[0] - - annotation_value = annotation.value - assert type(annotation_value) is Text - assert annotation_value.answer == "text_answer" - assert annotation_value.confidence == 0.5 - - -def test_nested_serialization(): - label = Label( - uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation( - name="nested test", - value=Checklist(answer=[ - ClassificationAnswer( - name="first_answer", - confidence=0.9, - classifications=[ - ClassificationAnnotation( - name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer", - confidence=0.8, - classifications=[ - ClassificationAnnotation( - name="nested answer", - value=Text( - answer="nested answer", - confidence=0.7, - )) - ]))) - ]) - ]), - ) - ]) - - serialized = NDJsonConverter.serialize([label]) - res = next(serialized) - - assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d" - answer = res['answer'][0] - assert answer['confidence'] == 0.9 - assert answer['name'] == "first_answer" - classification = answer['classifications'][0] - nested_classification_answer = classification['answer'] - assert nested_classification_answer['confidence'] == 0.8 - assert nested_classification_answer['name'] == "first_sub_radio_answer" - sub_classification = nested_classification_answer['classifications'][0] - assert sub_classification['name'] == "nested answer" - assert sub_classification['answer'] == "nested answer" - assert sub_classification['confidence'] == 0.7 - - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - annotation = res.annotations[0] - answer = annotation.value.answer[0] - assert answer.confidence == 0.9 - assert answer.name == "first_answer" - - classification_answer = answer.classifications[0].value.answer - assert classification_answer.confidence == 0.8 - assert classification_answer.name == "first_sub_radio_answer" - - sub_classification_answer = classification_answer.classifications[0].value - assert 
type(sub_classification_answer) is Text - assert sub_classification_answer.answer == "nested answer" - assert sub_classification_answer.confidence == 0.7 - ----- -tests/data/serialization/ndjson/test_dicom.py -from copy import copy -import pytest -import base64 -import labelbox.types as lb_types -from labelbox.data.serialization import NDJsonConverter -from labelbox.data.serialization.ndjson.objects import NDDicomSegments, NDDicomSegment, NDDicomLine -""" -Polyline test data -""" - -dicom_polyline_annotations = [ - lb_types.DICOMObjectAnnotation(uuid="78a8a027-9089-420c-8348-6099eb77e4aa", - name="dicom_polyline", - frame=2, - value=lb_types.Line(points=[ - lb_types.Point(x=680, y=100), - lb_types.Point(x=100, y=190), - lb_types.Point(x=190, y=220) - ]), - segment_index=0, - keyframe=True, - group_key=lb_types.GroupKey.AXIAL) -] - -polyline_label = lb_types.Label(data=lb_types.DicomData(uid="test-uid"), - annotations=dicom_polyline_annotations) - -polyline_annotation_ndjson = { - 'classifications': [], - 'dataRow': { - 'id': 'test-uid' - }, - 'name': - 'dicom_polyline', - 'groupKey': - 'axial', - 'segments': [{ - 'keyframes': [{ - 'frame': 2, - 'line': [ - { - 'x': 680.0, - 'y': 100.0 - }, - { - 'x': 100.0, - 'y': 190.0 - }, - { - 'x': 190.0, - 'y': 220.0 - }, - ], - 'classifications': [], - }] - }], -} - -polyline_with_global_key = lb_types.Label( - data=lb_types.DicomData(global_key="test-global-key"), - annotations=dicom_polyline_annotations) - -polyline_annotation_ndjson_with_global_key = copy(polyline_annotation_ndjson) -polyline_annotation_ndjson_with_global_key['dataRow'] = { - 'globalKey': 'test-global-key' -} -""" -Video test data -""" - -instance_uri_1 = 'https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA' -instance_uri_5 = 'https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA' -frames = [ - lb_types.MaskFrame(index=1, instance_uri=instance_uri_1), - lb_types.MaskFrame(index=5, instance_uri=instance_uri_5) -] -instances = [ - lb_types.MaskInstance(color_rgb=(0, 0, 255), name="mask1"), - lb_types.MaskInstance(color_rgb=(0, 255, 0), name="mask2"), - lb_types.MaskInstance(color_rgb=(255, 0, 0), name="mask3") -] - -video_mask_annotation = lb_types.VideoMaskAnnotation(frames=frames, - instances=instances) - -video_mask_annotation_ndjson = { - 'dataRow': { - 'id': 'test-uid' - }, - 'masks': { - 'frames': [{ - 'index': 1, - 'imBytes': None, - 'instanceURI': instance_uri_1 - }, { - 'index': 5, - 'imBytes': None, - 'instanceURI': instance_uri_5 - }], - 'instances': [ - { - 'colorRGB': (0, 0, 255), - 'name': 'mask1' - }, - { - 'colorRGB': (0, 255, 0), - 'name': 'mask2' - }, - { - 'colorRGB': (255, 0, 0), - 'name': 'mask3' - }, - ] - }, -} - -video_mask_annotation_ndjson_with_global_key = copy( - video_mask_annotation_ndjson) -video_mask_annotation_ndjson_with_global_key['dataRow'] = { - 'globalKey': 'test-global-key' -} - -video_mask_label = lb_types.Label(data=lb_types.VideoData(uid="test-uid"), - annotations=[video_mask_annotation]) - -video_mask_label_with_global_key = lb_types.Label( - data=lb_types.VideoData(global_key="test-global-key"), - annotations=[video_mask_annotation]) -""" -DICOM Mask test data -""" - -dicom_mask_annotation = lb_types.DICOMMaskAnnotation( - name="dicom_mask", - group_key=lb_types.GroupKey.AXIAL, 
- frames=frames, - instances=instances) - -dicom_mask_label = lb_types.Label(data=lb_types.DicomData(uid="test-uid"), - annotations=[dicom_mask_annotation]) - -dicom_mask_label_with_global_key = lb_types.Label( - data=lb_types.DicomData(global_key="test-global-key"), - annotations=[dicom_mask_annotation]) - -dicom_mask_annotation_ndjson = copy(video_mask_annotation_ndjson) -dicom_mask_annotation_ndjson['groupKey'] = 'axial' -dicom_mask_annotation_ndjson_with_global_key = copy( - dicom_mask_annotation_ndjson) -dicom_mask_annotation_ndjson_with_global_key['dataRow'] = { - 'globalKey': 'test-global-key' -} -""" -Tests -""" - -labels = [ - polyline_label, polyline_with_global_key, dicom_mask_label, - dicom_mask_label_with_global_key, video_mask_label, - video_mask_label_with_global_key -] -ndjsons = [ - polyline_annotation_ndjson, - polyline_annotation_ndjson_with_global_key, - dicom_mask_annotation_ndjson, - dicom_mask_annotation_ndjson_with_global_key, - video_mask_annotation_ndjson, - video_mask_annotation_ndjson_with_global_key, -] -labels_ndjsons = list(zip(labels, ndjsons)) - - -def test_deserialize_nd_dicom_segments(): - nd_dicom_segments = NDDicomSegments(**polyline_annotation_ndjson) - assert isinstance(nd_dicom_segments, NDDicomSegments) - assert isinstance(nd_dicom_segments.segments[0], NDDicomSegment) - assert isinstance(nd_dicom_segments.segments[0].keyframes[0], NDDicomLine) - - -@pytest.mark.parametrize('label, ndjson', labels_ndjsons) -def test_serialize_label(label, ndjson): - serialized_label = next(NDJsonConverter().serialize([label])) - serialized_label.pop('uuid') - assert serialized_label == ndjson - - -@pytest.mark.parametrize('label, ndjson', labels_ndjsons) -def test_deserialize_label(label, ndjson): - deserialized_label = next(NDJsonConverter().deserialize([ndjson])) - if hasattr(deserialized_label.annotations[0], 'extra'): - deserialized_label.annotations[0].extra = {} - assert deserialized_label.annotations == label.annotations - - -@pytest.mark.parametrize('label', labels) -def test_serialize_deserialize_label(label): - serialized = list(NDJsonConverter.serialize([label])) - deserialized = list(NDJsonConverter.deserialize(serialized)) - if hasattr(deserialized[0].annotations[0], 'extra'): - deserialized[0].annotations[0].extra = {} - assert deserialized[0].annotations == label.annotations - ----- -tests/data/serialization/ndjson/test_image.py -import json -import numpy as np -import cv2 - -from labelbox.data.serialization.ndjson.converter import NDJsonConverter -from labelbox.data.annotation_types import Mask, Label, ObjectAnnotation, ImageData, MaskData - - -def round_dict(data): - if isinstance(data, dict): - for key in data: - if isinstance(data[key], float): - data[key] = int(data[key]) - elif isinstance(data[key], dict): - data[key] = round_dict(data[key]) - elif isinstance(data[key], (list, tuple)): - data[key] = [round_dict(r) for r in data[key]] - - return data - - -def test_image(): - with open('tests/data/assets/ndjson/image_import.json', 'r') as file: - data = json.load(file) - - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - - for r in res: - r.pop('classifications', None) - assert [round_dict(x) for x in res] == [round_dict(x) for x in data] - - -def test_image_with_name_only(): - with open('tests/data/assets/ndjson/image_import_name_only.json', - 'r') as file: - data = json.load(file) - - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - for r in res: - 
r.pop('classifications', None) - assert [round_dict(x) for x in res] == [round_dict(x) for x in data] - - -def test_mask(): - data = [{ - "uuid": "b862c586-8614-483c-b5e6-82810f70cac0", - "schemaId": "ckrazcueb16og0z6609jj7y3y", - "dataRow": { - "id": "ckrazctum0z8a0ybc0b0o0g0v" - }, - "mask": { - "png": - "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAAAAACoWZBhAAAAMklEQVR4nD3MuQ3AQADDMOqQ/Vd2ijytaSiZLAcYuyLEYYYl9cvrlGftTHvsYl+u/3EDv0QLI8Z7FlwAAAAASUVORK5CYII=" - }, - "confidence": 0.8, - "customMetrics": [{ - "name": "customMetric1", - "value": 0.4 - }], - }, { - "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6", - "schemaId": "ckrazcuec16ok0z66f956apb7", - "dataRow": { - "id": "ckrazctum0z8a0ybc0b0o0g0v" - }, - "mask": { - "instanceURI": - "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY", - "colorRGB": [255, 0, 0] - } - }] - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - for r in res: - r.pop('classifications', None) - - assert [round_dict(x) for x in res] == [round_dict(x) for x in data] - - -def test_mask_from_arr(): - mask_arr = np.round(np.zeros((32, 32))).astype(np.uint8) - mask_arr = cv2.rectangle(mask_arr, (5, 5), (10, 10), (1, 1), -1) - - label = Label(annotations=[ - ObjectAnnotation(feature_schema_id="1" * 25, - value=Mask(mask=MaskData.from_2D_arr(arr=mask_arr), - color=(1, 1, 1))) - ], - data=ImageData(uid="0" * 25)) - res = next(NDJsonConverter.serialize([label])) - res.pop("uuid") - assert res == { - "classifications": [], - "schemaId": "1" * 25, - "dataRow": { - "id": "0" * 25 - }, - "mask": { - "png": - "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAAAAABWESUoAAAAHklEQVR4nGNgGAKAEYn8j00BEyETBoOCUTAKhhwAAJW+AQwvpePVAAAAAElFTkSuQmCC" - } - } - ----- -tests/data/serialization/ndjson/__init__.py - ----- -tests/data/serialization/ndjson/test_video.py -import json -from labelbox.client import Client -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnnotation, ClassificationAnswer, Radio -from labelbox.data.annotation_types.data.video import VideoData -from labelbox.data.annotation_types.geometry.line import Line -from labelbox.data.annotation_types.geometry.point import Point -from labelbox.data.annotation_types.geometry.rectangle import Rectangle -from labelbox.data.annotation_types.geometry.point import Point - -from labelbox.data.annotation_types.label import Label -from labelbox.data.annotation_types.video import VideoObjectAnnotation -from labelbox import parser - -from labelbox.data.serialization.ndjson.converter import NDJsonConverter - - -def test_video(): - with open('tests/data/assets/ndjson/video_import.json', 'r') as file: - data = json.load(file) - - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert res == [data[2], data[0], data[1], data[3], data[4], data[5]] - - -def test_video_name_only(): - with open('tests/data/assets/ndjson/video_import_name_only.json', - 'r') as file: - data = json.load(file) - - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert res == [data[2], data[0], data[1], data[3], data[4], data[5]] - - -def test_video_classification_global_subclassifications(): - label = Label( - data=VideoData(global_key="sample-video-4.mp4",), - annotations=[ - ClassificationAnnotation( - name='radio_question_nested', - value=Radio(answer=ClassificationAnswer( - 
name='first_radio_question')), - ), - ClassificationAnnotation( - name='nested_checklist_question', - value=Checklist( - name='checklist', - answer=[ - ClassificationAnswer( - name='first_checklist_answer', - classifications=[ - ClassificationAnnotation( - name='sub_checklist_question', - value=Radio(answer=ClassificationAnswer( - name='first_sub_checklist_answer'))) - ]) - ])) - ]) - - expected_first_annotation = { - 'name': 'radio_question_nested', - 'answer': { - 'name': 'first_radio_question' - }, - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - } - } - - expected_second_annotation = nested_checklist_annotation_ndjson = { - "name": "nested_checklist_question", - "answer": [{ - "name": - "first_checklist_answer", - "classifications": [{ - "name": "sub_checklist_question", - "answer": { - "name": "first_sub_checklist_answer" - } - }] - }], - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - } - } - - serialized = NDJsonConverter.serialize([label]) - res = [x for x in serialized] - for annotations in res: - annotations.pop("uuid") - assert res == [expected_first_annotation, expected_second_annotation] - - deserialized = NDJsonConverter.deserialize(res) - res = next(deserialized) - annotations = res.annotations - for annotation in annotations: - annotation.extra.pop("uuid") - assert annotations == label.annotations - - -def test_video_classification_nesting_bbox(): - bbox_annotation = [ - VideoObjectAnnotation( - name="bbox_video", - keyframe=True, - frame=13, - segment_index=0, - value=Rectangle( - start=Point(x=146.0, y=98.0), # Top left - end=Point(x=382.0, y=341.0), # Bottom right - ), - classifications=[ - ClassificationAnnotation( - name='radio_question_nested', - value=Radio(answer=ClassificationAnswer( - name='first_radio_question', - classifications=[ - ClassificationAnnotation(name='sub_question_radio', - value=Checklist(answer=[ - ClassificationAnswer( - name='sub_answer') - ])) - ])), - ) - ]), - VideoObjectAnnotation( - name="bbox_video", - keyframe=True, - frame=15, - segment_index=0, - value=Rectangle( - start=Point(x=146.0, y=98.0), # Top left - end=Point(x=382.0, y=341.0), # Bottom right, - ), - classifications=[ - ClassificationAnnotation( - name='nested_checklist_question', - value=Checklist(answer=[ - ClassificationAnswer( - name='first_checklist_answer', - classifications=[ - ClassificationAnnotation( - name='sub_checklist_question', - value=Radio(answer=ClassificationAnswer( - name='first_sub_checklist_answer'))) - ]) - ])) - ]), - VideoObjectAnnotation( - name="bbox_video", - keyframe=True, - frame=19, - segment_index=0, - value=Rectangle( - start=Point(x=146.0, y=98.0), # Top left - end=Point(x=382.0, y=341.0), # Bottom right - )) - ] - expected = [{ - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - }, - 'name': - 'bbox_video', - 'classifications': [], - 'segments': [{ - 'keyframes': [{ - 'frame': - 13, - 'bbox': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - }, - 'classifications': [{ - 'name': 'radio_question_nested', - 'answer': { - 'name': - 'first_radio_question', - 'classifications': [{ - 'name': 'sub_question_radio', - 'answer': [{ - 'name': 'sub_answer' - }] - }] - } - }] - }, { - 'frame': - 15, - 'bbox': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - }, - 'classifications': [{ - 'name': - 'nested_checklist_question', - 'answer': [{ - 'name': - 'first_checklist_answer', - 'classifications': [{ - 'name': 'sub_checklist_question', - 'answer': { - 'name': 'first_sub_checklist_answer' - } - }] - }] - }] - 
}, { - 'frame': 19, - 'bbox': { - 'top': 98.0, - 'left': 146.0, - 'height': 243.0, - 'width': 236.0 - }, - 'classifications': [] - }] - }] - }] - - label = Label(data=VideoData(global_key="sample-video-4.mp4",), - annotations=bbox_annotation) - - serialized = NDJsonConverter.serialize([label]) - res = [x for x in serialized] - for annotations in res: - annotations.pop("uuid") - assert res == expected - - deserialized = NDJsonConverter.deserialize(res) - res = next(deserialized) - annotations = res.annotations - for annotation in annotations: - annotation.extra.pop("uuid") - assert annotations == label.annotations - - -def test_video_classification_point(): - bbox_annotation = [ - VideoObjectAnnotation( - name="bbox_video", - keyframe=True, - frame=13, - segment_index=0, - value=Point(x=46.0, y=8.0), - classifications=[ - ClassificationAnnotation( - name='radio_question_nested', - value=Radio(answer=ClassificationAnswer( - name='first_radio_question', - classifications=[ - ClassificationAnnotation(name='sub_question_radio', - value=Checklist(answer=[ - ClassificationAnswer( - name='sub_answer') - ])) - ])), - ) - ]), - VideoObjectAnnotation( - name="bbox_video", - keyframe=True, - frame=15, - segment_index=0, - value=Point(x=56.0, y=18.0), - classifications=[ - ClassificationAnnotation( - name='nested_checklist_question', - value=Checklist(answer=[ - ClassificationAnswer( - name='first_checklist_answer', - classifications=[ - ClassificationAnnotation( - name='sub_checklist_question', - value=Radio(answer=ClassificationAnswer( - name='first_sub_checklist_answer'))) - ]) - ])) - ]), - VideoObjectAnnotation( - name="bbox_video", - keyframe=True, - frame=19, - segment_index=0, - value=Point(x=66.0, y=28.0), - ) - ] - expected = [{ - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - }, - 'name': - 'bbox_video', - 'classifications': [], - 'segments': [{ - 'keyframes': [{ - 'frame': - 13, - 'point': { - 'x': 46.0, - 'y': 8.0, - }, - 'classifications': [{ - 'name': 'radio_question_nested', - 'answer': { - 'name': - 'first_radio_question', - 'classifications': [{ - 'name': 'sub_question_radio', - 'answer': [{ - 'name': 'sub_answer' - }] - }] - } - }] - }, { - 'frame': - 15, - 'point': { - 'x': 56.0, - 'y': 18.0, - }, - 'classifications': [{ - 'name': - 'nested_checklist_question', - 'answer': [{ - 'name': - 'first_checklist_answer', - 'classifications': [{ - 'name': 'sub_checklist_question', - 'answer': { - 'name': 'first_sub_checklist_answer' - } - }] - }] - }] - }, { - 'frame': 19, - 'point': { - 'x': 66.0, - 'y': 28.0, - }, - 'classifications': [] - }] - }] - }] - - label = Label(data=VideoData(global_key="sample-video-4.mp4",), - annotations=bbox_annotation) - - serialized = NDJsonConverter.serialize([label]) - res = [x for x in serialized] - for annotations in res: - annotations.pop("uuid") - assert res == expected - - deserialized = NDJsonConverter.deserialize(res) - res = next(deserialized) - annotations = res.annotations - for annotation in annotations: - annotation.extra.pop("uuid") - assert annotations == label.annotations - - -def test_video_classification_frameline(): - bbox_annotation = [ - VideoObjectAnnotation( - name="bbox_video", - keyframe=True, - frame=13, - segment_index=0, - value=Line( - points=[Point(x=8, y=10), Point(x=10, y=9)]), - classifications=[ - ClassificationAnnotation( - name='radio_question_nested', - value=Radio(answer=ClassificationAnswer( - name='first_radio_question', - classifications=[ - ClassificationAnnotation(name='sub_question_radio', - 
value=Checklist(answer=[ - ClassificationAnswer( - name='sub_answer') - ])) - ])), - ) - ]), - VideoObjectAnnotation( - name="bbox_video", - keyframe=True, - frame=15, - segment_index=0, - value=Line( - points=[Point(x=18, y=20), Point(x=20, y=19)]), - classifications=[ - ClassificationAnnotation( - name='nested_checklist_question', - value=Checklist(answer=[ - ClassificationAnswer( - name='first_checklist_answer', - classifications=[ - ClassificationAnnotation( - name='sub_checklist_question', - value=Radio(answer=ClassificationAnswer( - name='first_sub_checklist_answer'))) - ]) - ])) - ]), - VideoObjectAnnotation( - name="bbox_video", - keyframe=True, - frame=19, - segment_index=0, - value=Line( - points=[Point(x=28, y=30), Point(x=30, y=29)]), - ) - ] - expected = [{ - 'dataRow': { - 'globalKey': 'sample-video-4.mp4' - }, - 'name': - 'bbox_video', - 'classifications': [], - 'segments': [{ - 'keyframes': [{ - 'frame': - 13, - 'line': [{ - 'x': 8.0, - 'y': 10.0, - }, { - 'x': 10.0, - 'y': 9.0, - }], - 'classifications': [{ - 'name': 'radio_question_nested', - 'answer': { - 'name': - 'first_radio_question', - 'classifications': [{ - 'name': 'sub_question_radio', - 'answer': [{ - 'name': 'sub_answer' - }] - }] - } - }] - }, { - 'frame': - 15, - 'line': [{ - 'x': 18.0, - 'y': 20.0, - }, { - 'x': 20.0, - 'y': 19.0, - }], - 'classifications': [{ - 'name': - 'nested_checklist_question', - 'answer': [{ - 'name': - 'first_checklist_answer', - 'classifications': [{ - 'name': 'sub_checklist_question', - 'answer': { - 'name': 'first_sub_checklist_answer' - } - }] - }] - }] - }, { - 'frame': 19, - 'line': [{ - 'x': 28.0, - 'y': 30.0, - }, { - 'x': 30.0, - 'y': 29.0, - }], - 'classifications': [] - }] - }] - }] - - label = Label(data=VideoData(global_key="sample-video-4.mp4",), - annotations=bbox_annotation) - - serialized = NDJsonConverter.serialize([label]) - res = [x for x in serialized] - for annotations in res: - annotations.pop("uuid") - assert res == expected - - deserialized = NDJsonConverter.deserialize(res) - res = next(deserialized) - annotations = res.annotations - for annotation in annotations: - annotation.extra.pop("uuid") - assert annotations == label.annotations ----- -tests/data/serialization/ndjson/test_nested.py -import json - -from labelbox.data.serialization.ndjson.converter import NDJsonConverter - - -def test_nested(): - with open('tests/data/assets/ndjson/nested_import.json', 'r') as file: - data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert res == data - - -def test_nested_name_only(): - with open('tests/data/assets/ndjson/nested_import_name_only.json', - 'r') as file: - data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert res == data - ----- -tests/data/serialization/ndjson/test_global_key.py -import json -import pytest - -from labelbox.data.serialization.ndjson.classification import NDRadio - -from labelbox.data.serialization.ndjson.converter import NDJsonConverter -from labelbox.data.serialization.ndjson.objects import NDLine - - -def round_dict(data): - if isinstance(data, dict): - for key in data: - if isinstance(data[key], float): - data[key] = int(data[key]) - elif isinstance(data[key], dict): - data[key] = round_dict(data[key]) - elif isinstance(data[key], (list, tuple)): - data[key] = [round_dict(r) for r in data[key]] - - return data - - -@pytest.mark.parametrize('filename', [ - 
'tests/data/assets/ndjson/classification_import_global_key.json', - 'tests/data/assets/ndjson/metric_import_global_key.json', - 'tests/data/assets/ndjson/polyline_import_global_key.json', - 'tests/data/assets/ndjson/text_entity_import_global_key.json', - 'tests/data/assets/ndjson/conversation_entity_import_global_key.json', -]) -def test_many_types(filename: str): - with open(filename, 'r') as f: - data = json.load(f) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert res == data - f.close() - - -def test_image(): - with open('tests/data/assets/ndjson/image_import_global_key.json', - 'r') as f: - data = json.load(f) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - for r in res: - r.pop('classifications', None) - assert [round_dict(x) for x in res] == [round_dict(x) for x in data] - f.close() - - -def test_pdf(): - with open('tests/data/assets/ndjson/pdf_import_global_key.json', 'r') as f: - data = json.load(f) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert [round_dict(x) for x in res] == [round_dict(x) for x in data] - f.close() - - -def test_video(): - with open('tests/data/assets/ndjson/video_import_global_key.json', - 'r') as f: - data = json.load(f) - - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert res == [data[2], data[0], data[1], data[3], data[4], data[5]] - f.close() - ----- -tests/data/serialization/ndjson/test_conversation.py -import json - -import pytest -import labelbox.types as lb_types -from labelbox.data.serialization.ndjson.converter import NDJsonConverter - -radio_ndjson = [{ - 'dataRow': { - 'globalKey': 'my_global_key' - }, - 'name': 'radio', - 'answer': { - 'name': 'first_radio_answer' - }, - 'messageId': '0' -}] - -radio_label = [ - lb_types.Label( - data=lb_types.ConversationData(global_key='my_global_key'), - annotations=[ - lb_types.ClassificationAnnotation( - name='radio', - value=lb_types.Radio(answer=lb_types.ClassificationAnswer( - name="first_radio_answer")), - message_id="0") - ]) -] - -checklist_ndjson = [{ - 'dataRow': { - 'globalKey': 'my_global_key' - }, - 'name': 'checklist', - 'answer': [ - { - 'name': 'first_checklist_answer' - }, - { - 'name': 'second_checklist_answer' - }, - ], - 'messageId': '2' -}] - -checklist_label = [ - lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'), - annotations=[ - lb_types.ClassificationAnnotation( - name='checklist', - message_id="2", - value=lb_types.Checklist(answer=[ - lb_types.ClassificationAnswer( - name="first_checklist_answer"), - lb_types.ClassificationAnswer( - name="second_checklist_answer") - ])) - ]) -] - -free_text_ndjson = [{ - 'dataRow': { - 'globalKey': 'my_global_key' - }, - 'name': 'free_text', - 'answer': 'sample text', - 'messageId': '0' -}] -free_text_label = [ - lb_types.Label(data=lb_types.ConversationData(global_key='my_global_key'), - annotations=[ - lb_types.ClassificationAnnotation( - name='free_text', - message_id="0", - value=lb_types.Text(answer="sample text")) - ]) -] - - -@pytest.mark.parametrize( - "label, ndjson", - [[radio_label, radio_ndjson], [checklist_label, checklist_ndjson], - [free_text_label, free_text_ndjson]]) -def test_message_based_radio_classification(label, ndjson): - serialized_label = list(NDJsonConverter().serialize(label)) - serialized_label[0].pop('uuid') - assert serialized_label == ndjson - - deserialized_label = 
list(NDJsonConverter().deserialize(ndjson)) - deserialized_label[0].annotations[0].extra.pop('uuid') - assert deserialized_label[0].annotations == label[0].annotations - - -@pytest.mark.parametrize("filename", [ - "tests/data/assets/ndjson/conversation_entity_import.json", - "tests/data/assets/ndjson/conversation_entity_without_confidence_import.json" -]) -def test_conversation_entity_import(filename: str): - with open(filename, 'r') as file: - data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert res == data - ----- -tests/data/serialization/ndjson/test_text.py -from labelbox.data.annotation_types.annotation import ClassificationAnnotation -from labelbox.data.annotation_types.classification.classification import ClassificationAnswer, Radio, Text -from labelbox.data.annotation_types.data.text import TextData -from labelbox.data.annotation_types.label import Label - -from labelbox.data.serialization.ndjson.converter import NDJsonConverter - - -def test_serialization(): - label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation( - name="radio_question_geo", - confidence=0.5, - value=Text(answer="first_radio_answer")) - ]) - - serialized = NDJsonConverter.serialize([label]) - res = next(serialized) - assert 'confidence' not in res # because confidence needs to be set on the annotation itself - assert res['name'] == "radio_question_geo" - assert res['answer'] == "first_radio_answer" - assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d" - - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - annotation = res.annotations[0] - - annotation_value = annotation.value - assert type(annotation_value) is Text - assert annotation_value.answer == "first_radio_answer" - ----- -tests/data/serialization/ndjson/test_export_video_objects.py -from labelbox.data.annotation_types import Label, VideoObjectAnnotation -from labelbox.data.serialization.ndjson.converter import NDJsonConverter -from labelbox.data.annotation_types.geometry import Rectangle, Point -from labelbox.data.annotation_types import VideoData - - -def video_bbox_label(): - return Label( - uid='cl1z52xwh00050fhcmfgczqvn', - data=VideoData( - uid="cklr9mr4m5iao0rb6cvxu4qbn", - file_path=None, - frames=None, - url= - "https://storage.labelbox.com/ckcz6bubudyfi0855o1dt1g9s%2F26403a22-604a-a38c-eeff-c2ed481fb40a-cat.mp4?Expires=1651677421050&KeyName=labelbox-assets-key-3&Signature=vF7gMyfHzgZdfbB8BHgd88Ws-Ms" - ), - annotations=[ - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=46.0), - end=Point(extra={}, - x=454.0, - y=295.0)), - classifications=[], - frame=1, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=42.5), - end=Point(extra={}, - x=427.25, - y=308.25)), - classifications=[], - frame=2, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, 
- 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=39.0), - end=Point(extra={}, - x=400.5, - y=321.5)), - classifications=[], - frame=3, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=35.5), - end=Point(extra={}, - x=373.75, - y=334.75)), - classifications=[], - frame=4, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=32.0), - end=Point(extra={}, - x=347.0, - y=348.0)), - classifications=[], - frame=5, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=132.0), - end=Point(extra={}, - x=283.0, - y=348.0)), - classifications=[], - frame=9, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=122.333), - end=Point(extra={}, - x=295.5, - y=348.0)), - classifications=[], - frame=10, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=112.667), - end=Point(extra={}, - x=308.0, - y=348.0)), - classifications=[], - frame=11, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=103.0), - end=Point(extra={}, - x=320.5, - y=348.0)), - classifications=[], - frame=12, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=93.333), - end=Point(extra={}, - x=333.0, - y=348.0)), - classifications=[], - frame=13, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=83.667), - end=Point(extra={}, - x=345.5, - y=348.0)), - classifications=[], - frame=14, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 
'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=74.0), - end=Point(extra={}, - x=358.0, - y=348.0)), - classifications=[], - frame=15, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=66.833), - end=Point(extra={}, - x=387.333, - y=348.0)), - classifications=[], - frame=16, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=59.667), - end=Point(extra={}, - x=416.667, - y=348.0)), - classifications=[], - frame=17, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=52.5), - end=Point(extra={}, - x=446.0, - y=348.0)), - classifications=[], - frame=18, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=45.333), - end=Point(extra={}, - x=475.333, - y=348.0)), - classifications=[], - frame=19, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=38.167), - end=Point(extra={}, - x=504.667, - y=348.0)), - classifications=[], - frame=20, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=31.0), - end=Point(extra={}, - x=534.0, - y=348.0)), - classifications=[], - frame=21, - keyframe=True), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=29.5), - end=Point(extra={}, - x=543.0, - y=348.0)), - classifications=[], - frame=22, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=28.0), - end=Point(extra={}, - x=552.0, - y=348.0)), - classifications=[], - frame=23, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - 
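                  # Frames with keyframe=False hold interpolated positions; the
                  # NDJSON serializer keeps only keyframes, grouped into
                  # 'segments' of consecutive tracking runs (asserted in
                  # test_serialize_video_objects below). Each Rectangle becomes
                  # a bbox dict: frame 1 has start=(70.0, 46.0) and
                  # end=(454.0, 295.0), giving top=46.0, left=70.0,
                  # height=295.0-46.0=249.0 and width=454.0-70.0=384.0.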
value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=26.5), - end=Point(extra={}, - x=561.0, - y=348.0)), - classifications=[], - frame=24, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=25.0), - end=Point(extra={}, - x=570.0, - y=348.0)), - classifications=[], - frame=25, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=23.5), - end=Point(extra={}, - x=579.0, - y=348.0)), - classifications=[], - frame=26, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=22.0), - end=Point(extra={}, - x=588.0, - y=348.0)), - classifications=[], - frame=27, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=20.5), - end=Point(extra={}, - x=597.0, - y=348.0)), - classifications=[], - frame=28, - keyframe=False), - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=19.0), - end=Point(extra={}, - x=606.0, - y=348.0)), - classifications=[], - frame=29, - keyframe=True) - ], - extra={ - 'Created By': - 'jtso@labelbox.com', - 'Project Name': - 'Pictor Video', - 'Created At': - '2022-04-14T15:11:19.000Z', - 'Updated At': - '2022-04-14T15:11:21.064Z', - 'Seconds to Label': - 0.0, - 'Agreement': - -1.0, - 'Benchmark Agreement': - -1.0, - 'Benchmark ID': - None, - 'Dataset Name': - 'cat', - 'Reviews': [], - 'View Label': - 'https://editor.labelbox.com?project=ckz38nsfd0lzq109bhq73est1&label=cl1z52xwh00050fhcmfgczqvn', - 'Has Open Issues': - 0.0, - 'Skipped': - False, - 'media_type': - 'video', - 'Data Split': - None - }) - - -def video_serialized_bbox_label(): - return { - 'uuid': - 'b24e672b-8f79-4d96-bf5e-b552ca0820d5', - 'dataRow': { - 'id': 'cklr9mr4m5iao0rb6cvxu4qbn' - }, - 'schemaId': - 'ckz38ofop0mci0z9i9w3aa9o4', - 'name': - 'bbox toy', - 'classifications': [], - 'segments': [{ - 'keyframes': [{ - 'frame': 1, - 'bbox': { - 'top': 46.0, - 'left': 70.0, - 'height': 249.0, - 'width': 384.0 - }, - 'classifications': [] - }, { - 'frame': 5, - 'bbox': { - 'top': 32.0, - 'left': 70.0, - 'height': 316.0, - 'width': 277.0 - }, - 'classifications': [] - }] - }, { - 'keyframes': [{ - 'frame': 9, - 'bbox': { - 'top': 132.0, - 'left': 70.0, - 'height': 216.0, - 'width': 213.0 - }, - 'classifications': [] - }, { - 'frame': 15, - 'bbox': { - 'top': 74.0, - 'left': 70.0, - 'height': 274.0, - 'width': 288.0 - }, - 'classifications': [] - }, { - 'frame': 21, - 'bbox': { - 'top': 31.0, - 'left': 
70.0, - 'height': 317.0, - 'width': 464.0 - }, - 'classifications': [] - }, { - 'frame': 29, - 'bbox': { - 'top': 19.0, - 'left': 70.0, - 'height': 329.0, - 'width': 536.0 - }, - 'classifications': [] - }] - }] - } - - -def test_serialize_video_objects(): - label = video_bbox_label() - serialized_labels = NDJsonConverter.serialize([label]) - label = next(serialized_labels) - - manual_label = video_serialized_bbox_label() - - for key in label.keys(): - # ignore uuid because we randomize if there was none - if key != "uuid": - assert label[key] == manual_label[key] - - assert len(label['segments']) == 2 - assert len(label['segments'][0]['keyframes']) == 2 - assert len(label['segments'][1]['keyframes']) == 4 - - # #converts back only the keyframes. should be the sum of all prev segments - deserialized_labels = NDJsonConverter.deserialize([label]) - label = next(deserialized_labels) - assert len(label.annotations) == 6 - - -def test_confidence_is_ignored(): - label = video_bbox_label() - serialized_labels = NDJsonConverter.serialize([label]) - label = next(serialized_labels) - label["confidence"] = 0.453 - label['segments'][0]["confidence"] = 0.453 - - deserialized_labels = NDJsonConverter.deserialize([label]) - label = next(deserialized_labels) - for annotation in label.annotations: - assert annotation.confidence is None - ----- -tests/data/serialization/ndjson/test_text_entity.py -import json - -import pytest - -from labelbox.data.serialization.ndjson.converter import NDJsonConverter - - -@pytest.mark.parametrize("filename", [ - "tests/data/assets/ndjson/text_entity_import.json", - "tests/data/assets/ndjson/text_entity_without_confidence_import.json" -]) -def test_text_entity_import(filename: str): - with open(filename, 'r') as file: - data = json.load(file) - - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert res == data - ----- -tests/data/serialization/ndjson/test_checklist.py -from labelbox.data.annotation_types.annotation import ClassificationAnnotation -from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnswer, Radio -from labelbox.data.annotation_types.data.text import TextData -from labelbox.data.annotation_types.label import Label - -from labelbox.data.serialization.ndjson.converter import NDJsonConverter - - -def test_serialization_min(): - label = Label( - uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation( - name="checkbox_question_geo", - value=Checklist( - answer=[ClassificationAnswer(name="first_answer")]), - ) - ]) - - expected = { - 'name': 'checkbox_question_geo', - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - }, - 'answer': [{ - 'name': 'first_answer' - }] - } - serialized = NDJsonConverter.serialize([label]) - res = next(serialized) - res.pop("uuid") - assert res == expected - - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - res.annotations[0].extra.pop("uuid") - assert res.annotations == label.annotations - - -def test_serialization_with_classification(): - label = Label( - uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation( - name="checkbox_question_geo", - confidence=0.5, - value=Checklist(answer=[ - ClassificationAnswer( - name="first_answer", - confidence=0.1, - classifications=[ - ClassificationAnnotation( - 
name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer", - confidence=0.31))), - ClassificationAnnotation( - name="sub_chck_question", - value=Checklist(answer=[ - ClassificationAnswer( - name="second_subchk_answer", - confidence=0.41), - ClassificationAnswer( - name="third_subchk_answer", - confidence=0.42), - ],)) - ]), - ])) - ]) - - expected = { - 'confidence': - 0.5, - 'name': - 'checkbox_question_geo', - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - }, - 'answer': [{ - 'confidence': - 0.1, - 'name': - 'first_answer', - 'classifications': [{ - 'name': 'sub_radio_question', - 'answer': { - 'confidence': 0.31, - 'name': 'first_sub_radio_answer', - } - }, { - 'name': - 'sub_chck_question', - 'answer': [{ - 'confidence': 0.41, - 'name': 'second_subchk_answer', - }, { - 'confidence': 0.42, - 'name': 'third_subchk_answer', - }] - }] - }] - } - - serialized = NDJsonConverter.serialize([label]) - res = next(serialized) - - res.pop("uuid") - assert res == expected - - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - res.annotations[0].extra.pop("uuid") - assert res.annotations == label.annotations - - -def test_serialization_with_classification_double_nested(): - label = Label( - uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation( - name="checkbox_question_geo", - confidence=0.5, - value=Checklist(answer=[ - ClassificationAnswer( - name="first_answer", - confidence=0.1, - classifications=[ - ClassificationAnnotation( - name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer", - confidence=0.31, - classifications=[ - ClassificationAnnotation( - name="sub_chck_question", - value=Checklist(answer=[ - ClassificationAnswer( - name="second_subchk_answer", - confidence=0.41), - ClassificationAnswer( - name="third_subchk_answer", - confidence=0.42), - ],)) - ]))), - ]), - ])) - ]) - - expected = { - 'confidence': - 0.5, - 'name': - 'checkbox_question_geo', - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - }, - 'answer': [{ - 'confidence': - 0.1, - 'name': - 'first_answer', - 'classifications': [{ - 'name': 'sub_radio_question', - 'answer': { - 'confidence': - 0.31, - 'name': - 'first_sub_radio_answer', - 'classifications': [{ - 'name': - 'sub_chck_question', - 'answer': [{ - 'confidence': 0.41, - 'name': 'second_subchk_answer', - }, { - 'confidence': 0.42, - 'name': 'third_subchk_answer', - }] - }] - } - }] - }] - } - serialized = NDJsonConverter.serialize([label]) - res = next(serialized) - - res.pop("uuid") - assert res == expected - - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - res.annotations[0].extra.pop("uuid") - assert res.annotations == label.annotations - - -def test_serialization_with_classification_double_nested_2(): - label = Label( - uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( - uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", - ), - annotations=[ - ClassificationAnnotation( - name="sub_radio_question", - value=Radio(answer=ClassificationAnswer( - name="first_sub_radio_answer", - confidence=0.31, - classifications=[ - ClassificationAnnotation( - name="sub_chck_question", - value=Checklist(answer=[ - ClassificationAnswer( - name="second_subchk_answer", - confidence=0.41, - classifications=[ - ClassificationAnnotation( - name="checkbox_question_geo", - value=Checklist(answer=[ - ClassificationAnswer( - 
name="first_answer", - confidence=0.1, - classifications=[]), - ])) - ]), - ClassificationAnswer(name="third_subchk_answer", - confidence=0.42), - ])) - ]))), - ]) - - expected = { - 'name': 'sub_radio_question', - 'answer': { - 'confidence': - 0.31, - 'name': - 'first_sub_radio_answer', - 'classifications': [{ - 'name': - 'sub_chck_question', - 'answer': [{ - 'confidence': - 0.41, - 'name': - 'second_subchk_answer', - 'classifications': [{ - 'name': 'checkbox_question_geo', - 'answer': [{ - 'confidence': 0.1, - 'name': 'first_answer', - }] - }] - }, { - 'confidence': 0.42, - 'name': 'third_subchk_answer', - }] - }] - }, - 'dataRow': { - 'id': 'bkj7z2q0b0000jx6x0q2q7q0d' - } - } - - serialized = NDJsonConverter.serialize([label]) - res = next(serialized) - res.pop("uuid") - assert res == expected - - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - res.annotations[0].extra.pop("uuid") - assert res.annotations == label.annotations - ----- -tests/data/serialization/ndjson/test_document.py -import json -import labelbox.types as lb_types -from labelbox.data.serialization.ndjson.converter import NDJsonConverter - -bbox_annotation = lb_types.ObjectAnnotation( - name="bounding_box", # must match your ontology feature's name - value=lb_types.DocumentRectangle( - start=lb_types.Point(x=42.799, y=86.498), # Top left - end=lb_types.Point(x=141.911, y=303.195), # Bottom right - page=1, - unit=lb_types.RectangleUnit.POINTS)) -bbox_labels = [ - lb_types.Label(data=lb_types.DocumentData(global_key='test-global-key'), - annotations=[bbox_annotation]) -] -bbox_ndjson = [{ - 'bbox': { - 'height': 216.697, - 'left': 42.799, - 'top': 86.498, - 'width': 99.112, - }, - 'classifications': [], - 'dataRow': { - 'globalKey': 'test-global-key' - }, - 'name': 'bounding_box', - 'page': 1, - 'unit': 'POINTS' -}] - - -def round_dict(data): - if isinstance(data, dict): - for key in data: - if isinstance(data[key], (int, float)): - data[key] = int(data[key]) - elif isinstance(data[key], dict): - data[key] = round_dict(data[key]) - elif isinstance(data[key], (list, tuple)): - data[key] = [round_dict(r) for r in data[key]] - - return data - - -def test_pdf(): - """ - Tests a pdf file with bbox annotations only - """ - with open('tests/data/assets/ndjson/pdf_import.json', 'r') as f: - data = json.load(f) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert [round_dict(x) for x in res] == [round_dict(x) for x in data] - f.close() - - -def test_pdf_with_name_only(): - """ - Tests a pdf file with bbox annotations only - """ - with open('tests/data/assets/ndjson/pdf_import_name_only.json', 'r') as f: - data = json.load(f) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert [round_dict(x) for x in res] == [round_dict(x) for x in data] - f.close() - - -def test_pdf_bbox_serialize(): - serialized = list(NDJsonConverter.serialize(bbox_labels)) - serialized[0].pop('uuid') - assert serialized == bbox_ndjson - - -def test_pdf_bbox_deserialize(): - deserialized = list(NDJsonConverter.deserialize(bbox_ndjson)) - deserialized[0].annotations[0].extra = {} - assert deserialized[0].annotations == bbox_labels[0].annotations - ----- -tests/data/serialization/ndjson/test_relationship.py -import json -from uuid import uuid4 - -import pytest - -from labelbox.data.serialization.ndjson.converter import NDJsonConverter - - -def test_relationship(): - with open("tests/data/assets/ndjson/relationship_import.json", "r") as 
file: - data = json.load(file) - - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert len(res) == len(data) - - res_relationship_annotation, res_relationship_second_annotation = [ - annot for annot in res if "relationship" in annot - ] - res_source_and_target = [ - annot for annot in res if "relationship" not in annot - ] - assert res_relationship_annotation - - assert res_relationship_annotation["relationship"]["source"] in [ - annot["uuid"] for annot in res_source_and_target - ] - assert res_relationship_annotation["relationship"]["target"] in [ - annot["uuid"] for annot in res_source_and_target - ] - - assert res_relationship_second_annotation - assert res_relationship_second_annotation["relationship"][ - "source"] != res_relationship_annotation["relationship"]["source"] - assert res_relationship_second_annotation["relationship"][ - "target"] != res_relationship_annotation["relationship"]["target"] - assert res_relationship_second_annotation["relationship"]["source"] in [ - annot["uuid"] for annot in res_source_and_target - ] - assert res_relationship_second_annotation["relationship"]["target"] in [ - annot["uuid"] for annot in res_source_and_target - ] - - -def test_relationship_nonexistent_object(): - with open("tests/data/assets/ndjson/relationship_import.json", "r") as file: - data = json.load(file) - - relationship_annotation = data[2] - source_uuid = relationship_annotation["relationship"]["source"] - target_uuid = str(uuid4()) - relationship_annotation["relationship"]["target"] = target_uuid - error_msg = f"Relationship object refers to nonexistent object with UUID '{source_uuid}' and/or '{target_uuid}'" - - with pytest.raises(ValueError, match=error_msg): - list(NDJsonConverter.deserialize(data)) - - -def test_relationship_duplicate_uuids(): - with open("tests/data/assets/ndjson/relationship_import.json", "r") as file: - data = json.load(file) - - source, target = data[0], data[1] - target["uuid"] = source["uuid"] - error_msg = f"UUID '{source['uuid']}' is not unique" - - with pytest.raises(AssertionError, match=error_msg): - list(NDJsonConverter.deserialize(data)) - ----- -tests/data/serialization/ndjson/test_classification.py -import json - -from labelbox.data.serialization.ndjson.converter import NDJsonConverter - - -def test_classification(): - with open('tests/data/assets/ndjson/classification_import.json', - 'r') as file: - data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert res == data - - -def test_classification_with_name(): - with open('tests/data/assets/ndjson/classification_import_name_only.json', - 'r') as file: - data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert res == data - ----- -tests/data/serialization/ndjson/test_polyline.py -import json -import pytest -from labelbox.data.serialization.ndjson.converter import NDJsonConverter - - -@pytest.mark.parametrize("filename", [ - "tests/data/assets/ndjson/polyline_without_confidence_import.json", - "tests/data/assets/ndjson/polyline_import.json" -]) -def test_polyline_import(filename: str): - with open(filename, 'r') as file: - data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert res == data - ----- -tests/data/serialization/coco/test_coco.py -import json -from pathlib import Path - -from labelbox.data.serialization.coco import COCOConverter - 
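# The NDJSON import tests above all share one round-trip shape: deserialize
# the raw NDJSON, re-serialize it, and assert the result equals the input.
# A minimal generic sketch of that pattern (ndjson_roundtrip is a hypothetical
# helper, not part of this suite):
import json

from labelbox.data.serialization.ndjson.converter import NDJsonConverter


def ndjson_roundtrip(filename: str) -> None:
    # Raw NDJSON -> annotation types -> NDJSON should be lossless.
    with open(filename, 'r') as file:
        data = json.load(file)
    res = list(NDJsonConverter.deserialize(data))
    res = list(NDJsonConverter.serialize(res))
    assert res == data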
-COCO_ASSETS_DIR = "tests/data/assets/coco" - - -def run_instances(tmpdir): - instance_json = json.load(open(Path(COCO_ASSETS_DIR, 'instances.json'))) - res = COCOConverter.deserialize_instances(instance_json, - Path(COCO_ASSETS_DIR, 'images')) - back = COCOConverter.serialize_instances( - res, - Path(tmpdir), - ) - - -def test_rle_objects(tmpdir): - rle_json = json.load(open(Path(COCO_ASSETS_DIR, 'rle.json'))) - res = COCOConverter.deserialize_instances(rle_json, - Path(COCO_ASSETS_DIR, 'images')) - back = COCOConverter.serialize_instances(res, tmpdir) - - -def test_panoptic(tmpdir): - panoptic_json = json.load(open(Path(COCO_ASSETS_DIR, 'panoptic.json'))) - image_dir, mask_dir = [ - Path(COCO_ASSETS_DIR, dir_name) for dir_name in ['images', 'masks'] - ] - res = COCOConverter.deserialize_panoptic(panoptic_json, image_dir, mask_dir) - back = COCOConverter.serialize_panoptic(res, - Path(f'/{tmpdir}/images_panoptic'), - Path(f'/{tmpdir}/masks_panoptic')) - ----- -tests/data/serialization/labelbox_v1/test_unknown_media.py -import json - -import pytest - -from labelbox.data.serialization.labelbox_v1.converter import LBV1Converter - - -def test_image(): - file_path = 'tests/data/assets/labelbox_v1/unkown_media_type_export.json' - with open(file_path, 'r') as file: - payload = json.load(file) - - collection = list(LBV1Converter.deserialize(payload)) - # One of the data rows is broken. - assert len(collection) != len(payload) - - for row in payload: - row['media_type'] = 'image' - row['Global Key'] = None - - collection = LBV1Converter.deserialize(payload) - for idx, serialized in enumerate(LBV1Converter.serialize(collection)): - assert serialized.keys() == payload[idx].keys() - for key in serialized: - if key != 'Label': - assert serialized[key] == payload[idx][key] - elif key == 'Label': - for annotation_a, annotation_b in zip( - serialized[key]['objects'], - payload[idx][key]['objects']): - if not len(annotation_a['classifications']): - # We don't add a classification key to the payload if there are no classifications. - annotation_a.pop('classifications') - annotation_b['page'] = None - annotation_b['unit'] = None - - if isinstance(annotation_b.get('classifications'), - list) and len( - annotation_b['classifications']): - if isinstance(annotation_b['classifications'][0], list): - annotation_b['classifications'] = annotation_b[ - 'classifications'][0] - - assert annotation_a == annotation_b - - -# TODO: after this passes, check the ND serializer against the same export; it should work for almost everything, apart from the same known-broken data rows. - ----- -tests/data/serialization/labelbox_v1/test_image.py -import json - -import pytest - -from labelbox.data.serialization.labelbox_v1.converter import LBV1Converter - - -@pytest.mark.parametrize("file_path", [ - 'tests/data/assets/labelbox_v1/highly_nested_image.json', - 'tests/data/assets/labelbox_v1/image_export.json' -]) -#TODO: some checklists from the export come in as [checklist ans: []] -# while others are checklist ans: []... when we can figure out why we sometimes -# have extra brackets, we can look into testing nested checklist answers -# and ensuring the export's output matches deserialized/serialized output -def test_image(file_path): - with open(file_path, 'r') as file: - payload = json.load(file) - - collection = LBV1Converter.deserialize([payload]) - serialized = next(LBV1Converter.serialize(collection)) - - # We are storing the media types now.
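        # The serialized output now includes 'media_type' and 'Global Key',
        # which the raw export payload lacks, so the payload is patched to
        # match before the key-by-key comparison.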
- payload['media_type'] = 'image' - payload['Global Key'] = None - - assert serialized.keys() == payload.keys() - - for key in serialized: - if key != 'Label': - assert serialized[key] == payload[key] - elif key == 'Label': - for annotation_a, annotation_b in zip(serialized[key]['objects'], - payload[key]['objects']): - annotation_b['page'] = None - annotation_b['unit'] = None - if not len(annotation_a['classifications']): - # We don't add a classification key to the payload if there are no classifications. - annotation_a.pop('classifications') - - if isinstance(annotation_b.get('classifications'), - list) and len(annotation_b['classifications']): - if isinstance(annotation_b['classifications'][0], list): - annotation_b['classifications'] = annotation_b[ - 'classifications'][0] - assert annotation_a == annotation_b - - -# TODO: after this passes, check the ND serializer against these exports as well; it should work for almost everything, apart from the known-broken edge cases. - ----- -tests/data/serialization/labelbox_v1/test_video.py -import json - -from labelbox.data.serialization.labelbox_v1.converter import LBV1Converter - - -def round_dict(data): - for key in data: - if isinstance(data[key], float): - data[key] = int(data[key]) - elif isinstance(data[key], dict): - data[key] = round_dict(data[key]) - return data - - -def test_video(): - payload = json.load( - open('tests/data/assets/labelbox_v1/video_export.json', 'r')) - collection = LBV1Converter.deserialize([payload]) - serialized = next(LBV1Converter.serialize(collection)) - payload['media_type'] = 'video' - payload['Global Key'] = None - assert serialized.keys() == payload.keys() - for key in serialized: - if key != 'Label': - assert serialized[key] == payload[key] - elif key == 'Label': - for annotation_a, annotation_b in zip(serialized[key], - payload[key]): - assert annotation_a['frameNumber'] == annotation_b[ - 'frameNumber'] - assert annotation_a['classifications'] == annotation_b[ - 'classifications'] - - for obj_a, obj_b in zip(annotation_a['objects'], - annotation_b['objects']): - obj_b['page'] = None - obj_b['unit'] = None - obj_a = round_dict(obj_a) - obj_b = round_dict(obj_b) - assert obj_a == obj_b - ----- -tests/data/serialization/labelbox_v1/test_text.py -import json - -from labelbox.data.serialization.labelbox_v1.converter import LBV1Converter - - -def test_text(): - with open('tests/data/assets/labelbox_v1/text_export.json', 'r') as file: - payload = json.load(file) - collection = LBV1Converter.deserialize([payload]) - serialized = next(LBV1Converter.serialize(collection)) - - payload['media_type'] = 'text' - payload['Global Key'] = None - - assert serialized.keys() == payload.keys() - for key in serialized: - if key != 'Label': - assert serialized[key] == payload[key] - elif key == 'Label': - for annotation_a, annotation_b in zip(serialized[key]['objects'], - payload[key]['objects']): - annotation_b['page'] = None - annotation_b['unit'] = None - if not len(annotation_a['classifications']): - # We don't add a classification key to the payload if there are no classifications.
- annotation_a.pop('classifications') - assert annotation_a == annotation_b - ----- -tests/data/serialization/labelbox_v1/test_tiled_image.py -import json - -import pytest -from labelbox.data.annotation_types.geometry.polygon import Polygon -from labelbox.data.annotation_types.geometry.line import Line -from labelbox.data.annotation_types.geometry.point import Point -from labelbox.data.annotation_types.geometry.rectangle import Rectangle - -from labelbox.data.serialization.labelbox_v1.converter import LBV1Converter -from labelbox.schema.bulk_import_request import Bbox - - -@pytest.mark.parametrize( - "file_path", ['tests/data/assets/labelbox_v1/tiled_image_export.json']) -def test_image(file_path): - """Tests against both Simple and non-Simple tiled image export data. - index-0 is non-Simple, index-1 is Simple - """ - with open(file_path, 'r') as f: - payload = json.load(f) - - collection = LBV1Converter.deserialize(payload) - collection_as_list = list(collection) - - assert len(collection_as_list) == 2 - - non_simple_annotations = collection_as_list[0].annotations - assert len(non_simple_annotations) == 6 - expected_shapes = [Polygon, Point, Point, Point, Line, Rectangle] - for idx in range(len(non_simple_annotations)): - assert isinstance(non_simple_annotations[idx].value, - expected_shapes[idx]) - assert non_simple_annotations[-1].value.start.x == -99.36567524971268 - assert non_simple_annotations[-1].value.start.y == 19.34717117508651 - assert non_simple_annotations[-1].value.end.x == -99.3649886680726 - assert non_simple_annotations[-1].value.end.y == 19.41999425190506 - - simple_annotations = collection_as_list[1].annotations - assert len(simple_annotations) == 8 - expected_shapes = [ - Polygon, Point, Point, Point, Point, Point, Line, Rectangle - ] - for idx in range(len(simple_annotations)): - assert isinstance(simple_annotations[idx].value, - expected_shapes[idx]) - ----- -tests/data/serialization/labelbox_v1/test_document.py -import json -from typing import Dict, Any - -from labelbox.data.serialization.labelbox_v1.converter import LBV1Converter - -IGNORE_KEYS = [ - "Data Split", "media_type", "DataRow Metadata", "Media Attributes" -] - - -def round_dict(data: Dict[str, Any]) -> Dict[str, Any]: - for key in data: - if isinstance(data[key], float): - data[key] = int(data[key]) - elif isinstance(data[key], dict): - data[key] = round_dict(data[key]) - return data - - -def test_pdf(): - """ - Tests an export from a pdf document with only bounding boxes - """ - payload = json.load( - open('tests/data/assets/labelbox_v1/pdf_export.json', 'r')) - collection = LBV1Converter.deserialize(payload) - serialized = next(LBV1Converter.serialize(collection)) - - payload = payload[0] # only one document in the export - - serialized = {k: v for k, v in serialized.items() if k not in IGNORE_KEYS} - - assert serialized.keys() == payload.keys() - for key in payload.keys(): - if key == 'Label': - serialized_no_classes = [{ - k: v for k, v in dic.items() if k != 'classifications' - } for dic in serialized[key]['objects']] - serialized_round = [ - round_dict(dic) for dic in serialized_no_classes - ] - payload_round = [round_dict(dic) for dic in payload[key]['objects']] - assert payload_round == serialized_round - else: - assert serialized[key] == payload[key] - ----- -tests/data/annotation_types/test_label.py -import numpy as np - -import labelbox.types as lb_types -from labelbox import OntologyBuilder, Tool, Classification as OClassification, Option -from labelbox.data.annotation_types import 
(ClassificationAnswer, Radio, Text, - ClassificationAnnotation, - ObjectAnnotation, Point, Line, - ImageData, Label) - - -def test_schema_assignment_geometry(): - name = "line_feature" - label = Label( - data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ObjectAnnotation( - value=Line( - points=[Point(x=1, y=2), Point(x=2, y=2)]), - name=name, - ) - ]) - feature_schema_id = "expected_id" - ontology = OntologyBuilder(tools=[ - Tool(Tool.Type.LINE, name=name, feature_schema_id=feature_schema_id) - ]) - label.assign_feature_schema_ids(ontology) - - assert label.annotations[0].feature_schema_id == feature_schema_id - - -def test_schema_assignment_classification(): - radio_name = "radio_name" - text_name = "text_name" - option_name = "my_option" - - label = Label(data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ClassificationAnnotation(value=Radio( - answer=ClassificationAnswer(name=option_name)), - name=radio_name), - ClassificationAnnotation(value=Text(answer="some text"), - name=text_name) - ]) - radio_schema_id = "radio_schema_id" - text_schema_id = "text_schema_id" - option_schema_id = "option_schema_id" - ontology = OntologyBuilder( - tools=[], - classifications=[ - OClassification(class_type=OClassification.Type.RADIO, - name=radio_name, - feature_schema_id=radio_schema_id, - options=[ - Option(value=option_name, - feature_schema_id=option_schema_id) - ]), - OClassification( - class_type=OClassification.Type.TEXT, - name=text_name, - feature_schema_id=text_schema_id, - ) - ]) - label.assign_feature_schema_ids(ontology) - assert label.annotations[0].feature_schema_id == radio_schema_id - assert label.annotations[1].feature_schema_id == text_schema_id - assert label.annotations[ - 0].value.answer.feature_schema_id == option_schema_id - - -def test_schema_assignment_subclass(): - name = "line_feature" - radio_name = "radio_name" - option_name = "my_option" - classification = ClassificationAnnotation( - name=radio_name, - value=Radio(answer=ClassificationAnswer(name=option_name)), - ) - label = Label( - data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ObjectAnnotation(value=Line( - points=[Point(x=1, y=2), Point(x=2, y=2)]), - name=name, - classifications=[classification]) - ]) - feature_schema_id = "expected_id" - classification_schema_id = "classification_id" - option_schema_id = "option_schema_id" - ontology = OntologyBuilder(tools=[ - Tool(Tool.Type.LINE, - name=name, - feature_schema_id=feature_schema_id, - classifications=[ - OClassification(class_type=OClassification.Type.RADIO, - name=radio_name, - feature_schema_id=classification_schema_id, - options=[ - Option(value=option_name, - feature_schema_id=option_schema_id) - ]) - ]) - ]) - label.assign_feature_schema_ids(ontology) - assert label.annotations[0].feature_schema_id == feature_schema_id - assert label.annotations[0].classifications[ - 0].feature_schema_id == classification_schema_id - assert label.annotations[0].classifications[ - 0].value.answer.feature_schema_id == option_schema_id - - -def test_highly_nested(): - name = "line_feature" - radio_name = "radio_name" - nested_name = "nested_name" - option_name = "my_option" - nested_option_name = "nested_option_name" - classification = ClassificationAnnotation( - name=radio_name, - value=Radio(answer=ClassificationAnswer(name=option_name)), - classifications=[ - ClassificationAnnotation(value=Radio(answer=ClassificationAnswer( - name=nested_option_name)), - name=nested_name) - ]) - label = Label( - 
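        # assign_feature_schema_ids (called below) walks the ontology and
        # matches tools, classifications, and options to annotations by name,
        # filling in feature_schema_id on every matched feature.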
data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ObjectAnnotation(value=Line( - points=[Point(x=1, y=2), Point(x=2, y=2)]), - name=name, - classifications=[classification]) - ]) - feature_schema_id = "expected_id" - classification_schema_id = "classification_id" - nested_classification_schema_id = "nested_classification_schema_id" - option_schema_id = "option_schema_id" - ontology = OntologyBuilder(tools=[ - Tool(Tool.Type.LINE, - name=name, - feature_schema_id=feature_schema_id, - classifications=[ - OClassification( - class_type=OClassification.Type.RADIO, - name=radio_name, - feature_schema_id=classification_schema_id, - options=[ - Option(value=option_name, - feature_schema_id=option_schema_id, - options=[ - OClassification( - class_type=OClassification.Type.RADIO, - name=nested_name, - feature_schema_id= - nested_classification_schema_id, - options=[ - Option( - value=nested_option_name, - feature_schema_id= - nested_classification_schema_id) - ]) - ]) - ]) - ]) - ]) - label.assign_feature_schema_ids(ontology) - assert label.annotations[0].feature_schema_id == feature_schema_id - assert label.annotations[0].classifications[ - 0].feature_schema_id == classification_schema_id - assert label.annotations[0].classifications[ - 0].value.answer.feature_schema_id == option_schema_id - - -def test_schema_assignment_confidence(): - name = "line_feature" - label = Label(data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ObjectAnnotation(value=Line( - points=[Point(x=1, y=2), - Point(x=2, y=2)],), - name=name, - confidence=0.914) - ]) - - assert label.annotations[0].confidence == 0.914 - - -def test_initialize_label_no_coercion(): - global_key = 'global-key' - ner_annotation = lb_types.ObjectAnnotation( - name="ner", - value=lb_types.ConversationEntity(start=0, end=8, message_id="4")) - label = Label(data=lb_types.ConversationData(global_key=global_key), - annotations=[ner_annotation]) - assert isinstance(label.data, lb_types.ConversationData) - assert label.data.global_key == global_key - ----- -tests/data/annotation_types/test_metrics.py -import pytest - -from labelbox.data.annotation_types.metrics import ConfusionMatrixAggregation, ScalarMetricAggregation -from labelbox.data.annotation_types.metrics import ConfusionMatrixMetric, ScalarMetric -from labelbox.data.annotation_types import ScalarMetric, Label, ImageData -from labelbox.data.annotation_types.metrics.scalar import RESERVED_METRIC_NAMES -from labelbox import pydantic_compat - - -def test_legacy_scalar_metric(): - value = 10 - metric = ScalarMetric(value=value) - assert metric.value == value - - label = Label(data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), - annotations=[metric]) - expected = { - 'data': { - 'external_id': None, - 'uid': 'ckrmd9q8g000009mg6vej7hzg', - 'global_key': None, - 'im_bytes': None, - 'file_path': None, - 'url': None, - 'arr': None, - 'media_attributes': None, - 'metadata': None, - }, - 'annotations': [{ - 'value': 10.0, - 'extra': {}, - }], - 'extra': {}, - 'uid': None - } - assert label.dict() == expected - - -# TODO: Test with confidence - - -@pytest.mark.parametrize('feature_name,subclass_name,aggregation,value', [ - ("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), - ("cat", None, ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), - (None, None, ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), - (None, None, None, 0.5), - ("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, 0.5), - ("cat", None, ScalarMetricAggregation.HARMONIC_MEAN, 0.5), 
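    # A dict value (the final case below) maps confidence thresholds to metric
    # values; per test_invalid_number_of_confidence_scores further down, a
    # single-score dict and a 20-score dict are both rejected.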
- (None, None, ScalarMetricAggregation.GEOMETRIC_MEAN, 0.5), - (None, None, ScalarMetricAggregation.SUM, 0.5), - ("cat", "orange", ScalarMetricAggregation.ARITHMETIC_MEAN, { - 0.1: 0.2, - 0.3: 0.5, - 0.4: 0.8 - }), -]) -def test_custom_scalar_metric(feature_name, subclass_name, aggregation, value): - kwargs = {'aggregation': aggregation} if aggregation is not None else {} - metric = ScalarMetric(metric_name="custom_iou", - value=value, - feature_name=feature_name, - subclass_name=subclass_name, - **kwargs) - assert metric.value == value - - label = Label(data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), - annotations=[metric]) - expected = { - 'data': { - 'external_id': None, - 'uid': 'ckrmd9q8g000009mg6vej7hzg', - 'global_key': None, - 'im_bytes': None, - 'file_path': None, - 'url': None, - 'arr': None, - 'media_attributes': None, - 'metadata': None, - }, - 'annotations': [{ - 'value': - value, - 'metric_name': - 'custom_iou', - **({ - 'feature_name': feature_name - } if feature_name else {}), - **({ - 'subclass_name': subclass_name - } if subclass_name else {}), 'aggregation': - aggregation or ScalarMetricAggregation.ARITHMETIC_MEAN, - 'extra': {} - }], - 'extra': {}, - 'uid': None - } - - assert label.dict() == expected - - -@pytest.mark.parametrize('feature_name,subclass_name,aggregation,value', [ - ("cat", "orange", ConfusionMatrixAggregation.CONFUSION_MATRIX, - (0, 1, 2, 3)), - ("cat", None, ConfusionMatrixAggregation.CONFUSION_MATRIX, (0, 1, 2, 3)), - (None, None, ConfusionMatrixAggregation.CONFUSION_MATRIX, (0, 1, 2, 3)), - (None, None, None, (0, 1, 2, 3)), - ("cat", "orange", ConfusionMatrixAggregation.CONFUSION_MATRIX, { - 0.1: (0, 1, 2, 3), - 0.3: (0, 1, 2, 3), - 0.4: (0, 1, 2, 3) - }), -]) -def test_custom_confusison_matrix_metric(feature_name, subclass_name, - aggregation, value): - kwargs = {'aggregation': aggregation} if aggregation is not None else {} - metric = ConfusionMatrixMetric(metric_name="confusion_matrix_50_pct_iou", - value=value, - feature_name=feature_name, - subclass_name=subclass_name, - **kwargs) - assert metric.value == value - - label = Label(data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), - annotations=[metric]) - expected = { - 'data': { - 'external_id': None, - 'uid': 'ckrmd9q8g000009mg6vej7hzg', - 'global_key': None, - 'im_bytes': None, - 'file_path': None, - 'url': None, - 'arr': None, - 'media_attributes': None, - 'metadata': None, - }, - 'annotations': [{ - 'value': - value, - 'metric_name': - 'confusion_matrix_50_pct_iou', - **({ - 'feature_name': feature_name - } if feature_name else {}), - **({ - 'subclass_name': subclass_name - } if subclass_name else {}), 'aggregation': - aggregation or ConfusionMatrixAggregation.CONFUSION_MATRIX, - 'extra': {} - }], - 'extra': {}, - 'uid': None - } - assert label.dict() == expected - - -def test_name_exists(): - # Name is only required for ConfusionMatrixMetric for now. 
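    # A ScalarMetric without metric_name is the legacy unnamed metric (see
    # test_legacy_scalar_metric above), so only ConfusionMatrixMetric rejects
    # a missing name here.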
- with pytest.raises(pydantic_compat.ValidationError) as exc_info: - metric = ConfusionMatrixMetric(value=[0, 1, 2, 3]) - assert "field required (type=value_error.missing)" in str(exc_info.value) - - -def test_invalid_aggregations(): - with pytest.raises(pydantic_compat.ValidationError) as exc_info: - metric = ScalarMetric( - metric_name="invalid aggregation", - value=0.1, - aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX) - assert "value is not a valid enumeration member" in str(exc_info.value) - with pytest.raises(pydantic_compat.ValidationError) as exc_info: - metric = ConfusionMatrixMetric(metric_name="invalid aggregation", - value=[0, 1, 2, 3], - aggregation=ScalarMetricAggregation.SUM) - assert "value is not a valid enumeration member" in str(exc_info.value) - - -def test_invalid_number_of_confidence_scores(): - with pytest.raises(pydantic_compat.ValidationError) as exc_info: - metric = ScalarMetric(metric_name="too few scores", value={0.1: 0.1}) - assert "Number of confidence scores must be greater" in str(exc_info.value) - with pytest.raises(pydantic_compat.ValidationError) as exc_info: - metric = ConfusionMatrixMetric(metric_name="too few scores", - value={0.1: [0, 1, 2, 3]}) - assert "Number of confidence scores must be greater" in str(exc_info.value) - with pytest.raises(pydantic_compat.ValidationError) as exc_info: - metric = ScalarMetric(metric_name="too many scores", - value={i / 20.: 0.1 for i in range(20)}) - assert "Number of confidence scores must be greater" in str(exc_info.value) - with pytest.raises(pydantic_compat.ValidationError) as exc_info: - metric = ConfusionMatrixMetric( - metric_name="too many scores", - value={i / 20.: [0, 1, 2, 3] for i in range(20)}) - assert "Number of confidence scores must be greater" in str(exc_info.value) - - -@pytest.mark.parametrize("metric_name", RESERVED_METRIC_NAMES) -def test_reserved_names(metric_name: str): - with pytest.raises(pydantic_compat.ValidationError) as exc_info: - ScalarMetric(metric_name=metric_name, value=0.5) - assert 'is a reserved metric name' in exc_info.value.errors()[0]['msg'] - ----- -tests/data/annotation_types/__init__.py - ----- -tests/data/annotation_types/test_video.py -import labelbox.types as lb_types - - -def test_mask_frame(): - mask_frame = lb_types.MaskFrame(index=1, - instance_uri="http://path/to/frame.png") - assert mask_frame.dict(by_alias=True) == { - 'index': 1, - 'imBytes': None, - 'instanceURI': 'http://path/to/frame.png' - } - - -def test_mask_instance(): - mask_instance = lb_types.MaskInstance(color_rgb=(0, 0, 255), name="mask1") - assert mask_instance.dict(by_alias=True) == { - 'colorRGB': (0, 0, 255), - 'name': 'mask1' - } - ----- -tests/data/annotation_types/test_collection.py -from types import SimpleNamespace -from uuid import uuid4 - -import numpy as np -import pytest - -from labelbox.data.annotation_types import (LabelGenerator, ObjectAnnotation, - ImageData, MaskData, Line, Mask, - Point, Label) -from labelbox import OntologyBuilder, Tool - - -@pytest.fixture -def list_of_labels(): - return [Label(data=ImageData(url="http://someurl")) for _ in range(5)] - - -@pytest.fixture -def signer(): - - def get_signer(uuid): - return lambda x: uuid - - return get_signer - - -class FakeDataset: - - def __init__(self): - self.uid = "ckrb4tgm51xl10ybc7lv9ghm7" - self.exports = [] - - def create_data_row(self, row_data, external_id=None): - if external_id is None: - external_id = "an external_id" - return SimpleNamespace(uid=self.uid, external_id=external_id) - - def 
create_data_rows(self, args): - for arg in args: - self.exports.append( - SimpleNamespace(row_data=arg['row_data'], - external_id=arg['external_id'], - uid=self.uid)) - return self - - def wait_till_done(self): - pass - - def export_data_rows(self): - for export in self.exports: - yield export - - -def test_generator(list_of_labels): - generator = LabelGenerator([list_of_labels[0]]) - - assert next(generator) == list_of_labels[0] - with pytest.raises(StopIteration): - next(generator) - - -def test_conversion(list_of_labels): - generator = LabelGenerator(list_of_labels) - label_collection = list(generator) - assert len(label_collection) == len(list_of_labels) - assert [x for x in label_collection] == list_of_labels - - -def test_adding_schema_ids(): - name = "line_feature" - label = Label( - data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ObjectAnnotation( - value=Line( - points=[Point(x=1, y=2), Point(x=2, y=2)]), - name=name, - ) - ]) - feature_schema_id = "expected_id" - ontology = OntologyBuilder(tools=[ - Tool(Tool.Type.LINE, name=name, feature_schema_id=feature_schema_id) - ]) - generator = LabelGenerator([label]).assign_feature_schema_ids(ontology) - assert next(generator).annotations[0].feature_schema_id == feature_schema_id - - -def test_adding_urls(signer): - label = Label(data=ImageData(arr=np.random.random((32, 32, - 3)).astype(np.uint8)), - annotations=[]) - uuid = str(uuid4()) - generator = LabelGenerator([label]).add_url_to_data(signer(uuid)) - assert label.data.url != uuid - assert next(generator).data.url == uuid - assert label.data.url == uuid - - -def test_adding_to_dataset(signer): - dataset = FakeDataset() - label = Label(data=ImageData(arr=np.random.random((32, 32, - 3)).astype(np.uint8)), - annotations=[]) - uuid = str(uuid4()) - generator = LabelGenerator([label]).add_to_dataset(dataset, signer(uuid)) - assert label.data.url != uuid - generated_label = next(generator) - assert generated_label.data.url == uuid - assert generated_label.data.external_id != None - assert generated_label.data.uid == dataset.uid - assert label.data.url == uuid - - -def test_adding_to_masks(signer): - label = Label( - data=ImageData(arr=np.random.random((32, 32, 3)).astype(np.uint8)), - annotations=[ - ObjectAnnotation(name="1234", - value=Mask(mask=MaskData( - arr=np.random.random((32, 32, - 3)).astype(np.uint8)), - color=[255, 255, 255])) - ]) - uuid = str(uuid4()) - generator = LabelGenerator([label]).add_url_to_masks(signer(uuid)) - assert label.annotations[0].value.mask.url != uuid - assert next(generator).annotations[0].value.mask.url == uuid - assert label.annotations[0].value.mask.url == uuid - ----- -tests/data/annotation_types/test_text.py -from labelbox.data.annotation_types.classification.classification import Text - - -def test_text(): - text_entity = Text(answer="good job") - assert text_entity.answer == "good job" - - -def test_text_confidence(): - text_entity = Text(answer="good job", confidence=0.5) - assert text_entity.answer == "good job" - assert text_entity.confidence == 0.5 - ----- -tests/data/annotation_types/test_tiled_image.py -import pytest -from labelbox.data.annotation_types.geometry.polygon import Polygon -from labelbox.data.annotation_types.geometry.point import Point -from labelbox.data.annotation_types.geometry.line import Line -from labelbox.data.annotation_types.geometry.rectangle import Rectangle -from labelbox.data.annotation_types.data.tiled_image import (EPSG, TiledBounds, - TileLayer, - TiledImageData, - EPSGTransformer) 
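# A minimal sketch of the projection API exercised below, reusing the same
# EPSG:3857 point as test_epsg_point_projections; create_geo_to_geo_transformer
# returns a callable that reprojects any supported shape:
from labelbox.data.annotation_types.data.tiled_image import (EPSG,
                                                             EPSGTransformer)
from labelbox.data.annotation_types.geometry.point import Point

transformer_3857_to_4326 = EPSGTransformer.create_geo_to_geo_transformer(
    src_epsg=EPSG.EPSG3857,
    tgt_epsg=EPSG.EPSG4326,
)
# Web-Mercator meters in, longitude/latitude degrees out.
point_4326 = transformer_3857_to_4326(
    shape=Point(x=-11016716.012685884, y=5312679.21393289))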
-from labelbox import pydantic_compat - - -@pytest.mark.parametrize("epsg", list(EPSG)) -def test_epsg(epsg): - assert isinstance(epsg, EPSG) - - -@pytest.mark.parametrize("epsg", list(EPSG)) -def test_tiled_bounds(epsg): - top_left = Point(x=0, y=0) - bottom_right = Point(x=50, y=50) - - tiled_bounds = TiledBounds(epsg=epsg, bounds=[top_left, bottom_right]) - assert isinstance(tiled_bounds, TiledBounds) - assert isinstance(tiled_bounds.epsg, EPSG) - - -@pytest.mark.parametrize("epsg", list(EPSG)) -def test_tiled_bounds_same(epsg): - single_bound = Point(x=0, y=0) - with pytest.raises(pydantic_compat.ValidationError): - tiled_bounds = TiledBounds(epsg=epsg, - bounds=[single_bound, single_bound]) - - -def test_create_tiled_image_data(): - bounds_points = [Point(x=0, y=0), Point(x=5, y=5)] - url = "https://labelbox.s3-us-west-2.amazonaws.com/pathology/{z}/{x}/{y}.png" - zoom_levels = (1, 10) - - tile_layer = TileLayer(url=url, name="slippy map tile") - tile_bounds = TiledBounds(epsg=EPSG.EPSG4326, bounds=bounds_points) - tiled_image_data = TiledImageData(tile_layer=tile_layer, - tile_bounds=tile_bounds, - zoom_levels=zoom_levels, - version=2) - assert isinstance(tiled_image_data, TiledImageData) - assert tiled_image_data.tile_bounds.bounds == bounds_points - assert tiled_image_data.tile_layer.url == url - assert tiled_image_data.zoom_levels == zoom_levels - - -def test_epsg_point_projections(): - zoom = 4 - - bounds_simple = TiledBounds(epsg=EPSG.SIMPLEPIXEL, - bounds=[Point(x=0, y=0), - Point(x=256, y=256)]) - - bounds_3857 = TiledBounds(epsg=EPSG.EPSG3857, - bounds=[ - Point(x=-104.150390625, y=30.789036751261136), - Point(x=-81.8701171875, y=45.920587344733654) - ]) - bounds_4326 = TiledBounds(epsg=EPSG.EPSG4326, - bounds=[ - Point(x=-104.150390625, y=30.789036751261136), - Point(x=-81.8701171875, y=45.920587344733654) - ]) - - point = Point(x=-11016716.012685884, y=5312679.21393289) - point_two = Point(x=-12016716.012685884, y=5212679.21393289) - point_three = Point(x=-13016716.012685884, y=5412679.21393289) - - line = Line(points=[point, point_two, point_three]) - polygon = Polygon(points=[point, point_two, point_three]) - rectangle = Rectangle(start=point, end=point_three) - - shapes_to_test = [point, line, polygon, rectangle] - - transformer_3857_simple = EPSGTransformer.create_geo_to_pixel_transformer( - src_epsg=EPSG.EPSG3857, - pixel_bounds=bounds_simple, - geo_bounds=bounds_3857, - zoom=zoom) - transformer_3857_4326 = EPSGTransformer.create_geo_to_geo_transformer( - src_epsg=EPSG.EPSG3857, - tgt_epsg=EPSG.EPSG4326, - ) - transformer_4326_simple = EPSGTransformer.create_geo_to_pixel_transformer( - src_epsg=EPSG.EPSG4326, - pixel_bounds=bounds_simple, - geo_bounds=bounds_4326, - zoom=zoom) - - for shape in shapes_to_test: - shape_simple = transformer_3857_simple(shape=shape) - - shape_4326 = transformer_3857_4326(shape=shape) - - other_simple_shape = transformer_4326_simple(shape=shape_4326) - - assert shape_simple == other_simple_shape - ----- -tests/data/annotation_types/test_annotation.py -import pytest - -from labelbox.data.annotation_types import (Text, Point, Line, - ClassificationAnnotation, - ObjectAnnotation, TextEntity) -from labelbox.data.annotation_types.video import VideoObjectAnnotation -from labelbox.data.annotation_types.geometry.rectangle import Rectangle -from labelbox.data.annotation_types.video import VideoClassificationAnnotation -from labelbox.exceptions import ConfidenceNotSupportedException -from labelbox import pydantic_compat - - -def test_annotation(): - 
name = "line_feature" - line = Line(points=[Point(x=1, y=2), Point(x=2, y=2)]) - classification = Text(answer="1234") - - annotation = ObjectAnnotation( - value=line, - name=name, - ) - assert annotation.value.points[0].dict() == {'extra': {}, 'x': 1., 'y': 2.} - assert annotation.name == name - - # Check ner - ObjectAnnotation( - value=TextEntity(start=10, end=12), - name=name, - ) - - # Check classification - ClassificationAnnotation( - value=classification, - name=name, - ) - - # Invalid subclass - with pytest.raises(pydantic_compat.ValidationError): - ObjectAnnotation( - value=line, - name=name, - classifications=[line], - ) - - subclass = ClassificationAnnotation(value=classification, name=name) - - ObjectAnnotation( - value=line, - name=name, - classifications=[subclass], - ) - - -def test_video_annotations(): - name = "line_feature" - line = Line(points=[Point(x=1, y=2), Point(x=2, y=2)]) - - # Wrong type - with pytest.raises(pydantic_compat.ValidationError): - VideoClassificationAnnotation(value=line, name=name, frame=1) - - # Missing frames - with pytest.raises(pydantic_compat.ValidationError): - VideoClassificationAnnotation(value=line, name=name) - - VideoObjectAnnotation(value=line, name=name, keyframe=True, frame=2) - - -def test_confidence_for_video_is_not_supported(): - with pytest.raises(ConfidenceNotSupportedException): - VideoObjectAnnotation(name='bbox toy', - feature_schema_id='ckz38ofop0mci0z9i9w3aa9o4', - extra={ - 'value': 'bbox_toy', - 'instanceURI': None, - 'color': '#1CE6FF', - 'feature_id': 'cl1z52xw700000fhcayaqy0ev' - }, - value=Rectangle(extra={}, - start=Point(extra={}, - x=70.0, - y=26.5), - end=Point(extra={}, - x=561.0, - y=348.0)), - classifications=[], - frame=24, - keyframe=False, - confidence=0.3434), - - -def test_confidence_value_range_validation(): - name = "line_feature" - line = Line(points=[Point(x=1, y=2), Point(x=2, y=2)]) - - with pytest.raises(ValueError) as e: - ObjectAnnotation(value=line, name=name, confidence=14) - assert e.value.errors()[0]['msg'] == 'must be a number within [0,1] range' - ----- -tests/data/annotation_types/test_ner.py -from labelbox.data.annotation_types import TextEntity, DocumentEntity, DocumentTextSelection -from labelbox.data.annotation_types.ner.conversation_entity import ConversationEntity - - -def test_ner(): - start = 10 - end = 12 - text_entity = TextEntity(start=start, end=end) - assert text_entity.start == start - assert text_entity.end == end - - -def test_document_entity(): - document_entity = DocumentEntity(text_selections=[ - DocumentTextSelection(token_ids=["1", "2"], group_id="1", page=1) - ]) - - assert document_entity.text_selections[0].token_ids == ["1", "2"] - assert document_entity.text_selections[0].group_id == "1" - assert document_entity.text_selections[0].page == 1 - - -def test_conversation_entity(): - conversation_entity = ConversationEntity(message_id=1, start=0, end=1) - - assert conversation_entity.message_id == "1" - assert conversation_entity.start == 0 - assert conversation_entity.end == 1 - ----- -tests/data/annotation_types/classification/__init__.py - ----- -tests/data/annotation_types/classification/test_classification.py -import pytest - -from labelbox.data.annotation_types import (Checklist, ClassificationAnswer, - Dropdown, Radio, Text, - ClassificationAnnotation) - -from labelbox import pydantic_compat - - -def test_classification_answer(): - with pytest.raises(pydantic_compat.ValidationError): - ClassificationAnswer() - - feature_schema_id = "schema_id" - name = "my_feature" - 
confidence = 0.9 - custom_metrics = [{'name': 'metric1', 'value': 2}] - answer = ClassificationAnswer(name=name, - confidence=confidence, - custom_metrics=custom_metrics) - - assert answer.feature_schema_id is None - assert answer.name == name - assert answer.confidence == confidence - assert answer.custom_metrics == custom_metrics - - answer = ClassificationAnswer(feature_schema_id=feature_schema_id, - name=name) - - assert answer.feature_schema_id == feature_schema_id - assert answer.name == name - - -def test_classification(): - answer = "1234" - classification = ClassificationAnnotation(value=Text(answer=answer), - name="a classification") - assert classification.dict()['value']['answer'] == answer - - with pytest.raises(pydantic_compat.ValidationError): - ClassificationAnnotation() - - -def test_subclass(): - answer = "1234" - feature_schema_id = "11232" - name = "my_feature" - with pytest.raises(pydantic_compat.ValidationError): - # Should have feature schema info - classification = ClassificationAnnotation(value=Text(answer=answer)) - classification = ClassificationAnnotation(value=Text(answer=answer), - name=name) - assert classification.dict() == { - 'name': name, - 'feature_schema_id': None, - 'extra': {}, - 'value': { - 'answer': answer, - }, - 'message_id': None, - } - classification = ClassificationAnnotation( - value=Text(answer=answer), - name=name, - feature_schema_id=feature_schema_id) - assert classification.dict() == { - 'name': None, - 'feature_schema_id': feature_schema_id, - 'extra': {}, - 'value': { - 'answer': answer, - }, - 'name': name, - 'message_id': None, - } - classification = ClassificationAnnotation( - value=Text(answer=answer), - feature_schema_id=feature_schema_id, - name=name) - assert classification.dict() == { - 'name': name, - 'feature_schema_id': feature_schema_id, - 'extra': {}, - 'value': { - 'answer': answer, - }, - 'message_id': None, - } - - -def test_radio(): - answer = ClassificationAnswer(name="1", - confidence=0.81, - custom_metrics=[{ - 'name': 'metric1', - 'value': 0.99 - }]) - feature_schema_id = "feature_schema_id" - name = "my_feature" - - with pytest.raises(pydantic_compat.ValidationError): - classification = ClassificationAnnotation(value=Radio( - answer=answer.name)) - - with pytest.raises(pydantic_compat.ValidationError): - classification = Radio(answer=[answer]) - classification = Radio(answer=answer,) - assert classification.dict() == { - 'answer': { - 'name': answer.name, - 'feature_schema_id': None, - 'extra': {}, - 'confidence': 0.81, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 0.99 - }], - } - } - classification = ClassificationAnnotation( - value=Radio(answer=answer), - feature_schema_id=feature_schema_id, - name=name, - custom_metrics=[{ - 'name': 'metric1', - 'value': 0.99 - }]) - assert classification.dict() == { - 'name': name, - 'feature_schema_id': feature_schema_id, - 'extra': {}, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 0.99 - }], - 'value': { - 'answer': { - 'name': answer.name, - 'feature_schema_id': None, - 'extra': {}, - 'confidence': 0.81, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 0.99 - }] - }, - }, - 'message_id': None, - } - - -def test_checklist(): - answer = ClassificationAnswer(name="1", - confidence=0.99, - custom_metrics=[{ - 'name': 'metric1', - 'value': 2 - }]) - feature_schema_id = "feature_schema_id" - name = "my_feature" - - with pytest.raises(pydantic_compat.ValidationError): - classification = Checklist(answer=answer.name) - - with 
pytest.raises(pydantic_compat.ValidationError): - classification = Checklist(answer=answer) - - classification = Checklist(answer=[answer]) - assert classification.dict() == { - 'answer': [{ - 'name': answer.name, - 'feature_schema_id': None, - 'extra': {}, - 'confidence': 0.99, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 2 - }], - }] - } - classification = ClassificationAnnotation( - value=Checklist(answer=[answer]), - feature_schema_id=feature_schema_id, - name=name, - ) - assert classification.dict() == { - 'name': name, - 'feature_schema_id': feature_schema_id, - 'extra': {}, - 'value': { - 'answer': [{ - 'name': answer.name, - 'feature_schema_id': None, - 'extra': {}, - 'confidence': 0.99, - 'custom_metrics': [{ - 'name': 'metric1', - 'value': 2 - }], - }] - }, - 'message_id': None, - } - - -def test_dropdown(): - answer = ClassificationAnswer(name="1", confidence=1) - feature_schema_id = "feature_schema_id" - name = "my_feature" - - with pytest.raises(pydantic_compat.ValidationError): - classification = ClassificationAnnotation( - value=Dropdown(answer=answer.name), name="test") - - with pytest.raises(pydantic_compat.ValidationError): - classification = Dropdown(answer=answer) - classification = Dropdown(answer=[answer]) - assert classification.dict() == { - 'answer': [{ - 'name': '1', - 'feature_schema_id': None, - 'extra': {}, - 'confidence': 1 - }] - } - classification = ClassificationAnnotation( - value=Dropdown(answer=[answer]), - feature_schema_id=feature_schema_id, - name=name) - assert classification.dict() == { - 'name': name, - 'feature_schema_id': feature_schema_id, - 'extra': {}, - 'value': { - 'answer': [{ - 'name': answer.name, - 'feature_schema_id': None, - 'confidence': 1, - 'extra': {} - }] - }, - 'message_id': None, - } - ----- -tests/data/annotation_types/geometry/test_rectangle.py -import cv2 -import pytest - -from labelbox.data.annotation_types import Point, Rectangle -from labelbox import pydantic_compat - - -def test_rectangle(): - with pytest.raises(pydantic_compat.ValidationError): - rectangle = Rectangle() - - rectangle = Rectangle(start=Point(x=0, y=1), end=Point(x=10, y=10)) - points = [[[0.0, 1.0], [0.0, 10.0], [10.0, 10.0], [10.0, 1.0], [0.0, 1.0]]] - expected = {"coordinates": points, "type": "Polygon"} - assert rectangle.geometry == expected - expected['coordinates'] = tuple([tuple([tuple(x) for x in points[0]])]) - assert rectangle.shapely.__geo_interface__ == expected - - raster = rectangle.draw(height=32, width=32) - assert (cv2.imread("tests/data/assets/rectangle.png") == raster).all() - - xyhw = Rectangle.from_xyhw(0., 0, 10, 10) - assert xyhw.start == Point(x=0, y=0.) 
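-    # from_xyhw is a convenience constructor: a top-left corner plus height
-    # and width should yield the same derived start/end corner points.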
- assert xyhw.end == Point(x=10, y=10.0) - ----- -tests/data/annotation_types/geometry/test_point.py -import pytest -import cv2 - -from labelbox.data.annotation_types import Point -from labelbox import pydantic_compat - - -def test_point(): - with pytest.raises(pydantic_compat.ValidationError): - line = Point() - - with pytest.raises(TypeError): - line = Point([0, 1]) - - point = Point(x=0, y=1) - expected = {"coordinates": [0, 1], "type": "Point"} - assert point.geometry == expected - expected['coordinates'] = tuple(expected['coordinates']) - assert point.shapely.__geo_interface__ == expected - - raster = point.draw(height=32, width=32, thickness=1) - assert (cv2.imread("tests/data/assets/point.png") == raster).all() - ----- -tests/data/annotation_types/geometry/test_mask.py -import pytest - -import numpy as np -import cv2 - -from labelbox.data.annotation_types import Point, Rectangle, Mask, MaskData -from labelbox import pydantic_compat - - -def test_mask(): - with pytest.raises(pydantic_compat.ValidationError): - mask = Mask() - - mask_data = np.zeros((32, 32, 3), dtype=np.uint8) - mask_data = cv2.rectangle(mask_data, (0, 0), (10, 10), (255, 255, 255), -1) - mask_data = cv2.rectangle(mask_data, (20, 20), (30, 30), (0, 255, 255), -1) - mask_data = MaskData(arr=mask_data) - - mask1 = Mask(mask=mask_data, color=(255, 255, 255)) - - expected1 = { - 'type': - 'MultiPolygon', - 'coordinates': [ - (((0.0, 0.0), (0.0, 1.0), (0.0, 2.0), (0.0, 3.0), (0.0, 4.0), (0.0, - 5.0), - (0.0, 6.0), (0.0, 7.0), (0.0, 8.0), (0.0, 9.0), (0.0, 10.0), - (1.0, 10.0), (2.0, 10.0), (3.0, 10.0), (4.0, 10.0), (5.0, 10.0), - (6.0, 10.0), (7.0, 10.0), (8.0, 10.0), (9.0, 10.0), (10.0, 10.0), - (10.0, 9.0), (10.0, 8.0), (10.0, 7.0), (10.0, 6.0), (10.0, 5.0), - (10.0, 4.0), (10.0, 3.0), (10.0, 2.0), (10.0, 1.0), (10.0, 0.0), - (9.0, 0.0), (8.0, 0.0), (7.0, 0.0), (6.0, 0.0), (5.0, 0.0), - (4.0, 0.0), (3.0, 0.0), (2.0, 0.0), (1.0, 0.0), (0.0, 0.0)),) - ] - } - assert mask1.geometry == expected1 - assert mask1.shapely.__geo_interface__ == expected1 - - mask2 = Mask(mask=mask_data, color=(0, 255, 255)) - expected2 = { - 'type': - 'MultiPolygon', - 'coordinates': [ - (((20.0, 20.0), (20.0, 21.0), (20.0, 22.0), (20.0, 23.0), - (20.0, 24.0), (20.0, 25.0), (20.0, 26.0), (20.0, 27.0), - (20.0, 28.0), (20.0, 29.0), (20.0, 30.0), (21.0, 30.0), - (22.0, 30.0), (23.0, 30.0), (24.0, 30.0), (25.0, 30.0), - (26.0, 30.0), (27.0, 30.0), (28.0, 30.0), (29.0, 30.0), - (30.0, 30.0), (30.0, 29.0), (30.0, 28.0), (30.0, 27.0), - (30.0, 26.0), (30.0, 25.0), (30.0, 24.0), (30.0, 23.0), - (30.0, 22.0), (30.0, 21.0), (30.0, 20.0), (29.0, 20.0), - (28.0, 20.0), (27.0, 20.0), (26.0, 20.0), (25.0, 20.0), - (24.0, 20.0), (23.0, 20.0), (22.0, 20.0), (21.0, 20.0), (20.0, - 20.0)),) - ] - } - assert mask2.geometry == expected2 - assert mask2.shapely.__geo_interface__ == expected2 - gt_mask = cv2.cvtColor(cv2.imread("tests/data/assets/mask.png"), - cv2.COLOR_BGR2RGB) - assert (gt_mask == mask1.mask.arr).all() - assert (gt_mask == mask2.mask.arr).all() - - raster1 = mask1.draw() - raster2 = mask2.draw() - - assert (raster1 != raster2).any() - - gt1 = Rectangle(start=Point(x=0, y=0), - end=Point(x=10, y=10)).draw(height=raster1.shape[0], - width=raster1.shape[1], - color=(255, 255, 255)) - gt2 = Rectangle(start=Point(x=20, y=20), - end=Point(x=30, y=30)).draw(height=raster2.shape[0], - width=raster2.shape[1], - color=(0, 255, 255)) - assert (raster1 == gt1).all() - assert (raster2 == gt2).all() - ----- -tests/data/annotation_types/geometry/test_line.py 
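-# A Line's .geometry serializes as a GeoJSON-style MultiLineString and .draw()
-# rasterizes it with OpenCV; the assertions below exercise both paths.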
-import pytest -import cv2 - -from labelbox.data.annotation_types.geometry import Point, Line -from labelbox import pydantic_compat - - -def test_line(): - with pytest.raises(pydantic_compat.ValidationError): - line = Line() - - with pytest.raises(pydantic_compat.ValidationError): - line = Line(points=[[0, 1], [2, 3]]) - - points = [[0, 1], [0, 2], [2, 2]] - expected = {"coordinates": [points], "type": "MultiLineString"} - line = Line(points=[Point(x=x, y=y) for x, y in points]) - assert line.geometry == expected - expected['coordinates'] = tuple([tuple([tuple(x) for x in points])]) - assert line.shapely.__geo_interface__ == expected - - raster = line.draw(height=32, width=32, thickness=1) - assert (cv2.imread("tests/data/assets/line.png") == raster).all() - ----- -tests/data/annotation_types/geometry/__init__.py - ----- -tests/data/annotation_types/geometry/test_polygon.py -import pytest -import cv2 - -from labelbox.data.annotation_types import Polygon, Point -from labelbox import pydantic_compat - - -def test_polygon(): - with pytest.raises(pydantic_compat.ValidationError): - polygon = Polygon() - - with pytest.raises(pydantic_compat.ValidationError): - polygon = Polygon(points=[[0, 1], [2, 3]]) - - with pytest.raises(pydantic_compat.ValidationError): - polygon = Polygon(points=[Point(x=0, y=1), Point(x=0, y=1)]) - - points = [[0., 1.], [0., 2.], [2., 2.], [2., 0.]] - expected = {"coordinates": [points + [points[0]]], "type": "Polygon"} - polygon = Polygon(points=[Point(x=x, y=y) for x, y in points]) - assert polygon.geometry == expected - expected['coordinates'] = tuple( - [tuple([tuple(x) for x in points + [points[0]]])]) - assert polygon.shapely.__geo_interface__ == expected - - raster = polygon.draw(10, 10) - assert (cv2.imread("tests/data/assets/polygon.png") == raster).all() - ----- -tests/data/annotation_types/data/test_raster.py -import urllib.request -from io import BytesIO - -import numpy as np -import pytest -from PIL import Image - -from labelbox.data.annotation_types.data import ImageData -from labelbox import pydantic_compat - - -def test_validate_schema(): - with pytest.raises(pydantic_compat.ValidationError): - data = ImageData() - - -def test_im_bytes(): - data = (np.random.random((32, 32, 3)) * 255).astype(np.uint8) - im_bytes = BytesIO() - Image.fromarray(data).save(im_bytes, format="PNG") - raster_data = ImageData(im_bytes=im_bytes.getvalue()) - data_ = raster_data.value - assert np.all(data == data_) - - -def test_im_url(): - raster_data = ImageData(url="https://picsum.photos/id/829/200/300") - data_ = raster_data.value - assert data_.shape == (300, 200, 3) - - -def test_im_path(): - img_path = "/tmp/img.jpg" - urllib.request.urlretrieve("https://picsum.photos/id/829/200/300", img_path) - raster_data = ImageData(file_path=img_path) - data_ = raster_data.value - assert data_.shape == (300, 200, 3) - - -def test_ref(): - external_id = "external_id" - uid = "uid" - metadata = [] - media_attributes = {} - data = ImageData(im_bytes=b'', - external_id=external_id, - uid=uid, - metadata=metadata, - media_attributes=media_attributes) - assert data.external_id == external_id - assert data.uid == uid - assert data.media_attributes == media_attributes - assert data.metadata == metadata - ----- -tests/data/annotation_types/data/__init__.py - ----- -tests/data/annotation_types/data/test_video.py -import numpy as np -import pytest - -from labelbox.data.annotation_types import VideoData -from labelbox import pydantic_compat - - -def test_validate_schema(): - with 
pytest.raises(pydantic_compat.ValidationError): - data = VideoData() - - -def test_frames(): - data = { - x: (np.random.random((32, 32, 3)) * 255).astype(np.uint8) - for x in range(5) - } - video_data = VideoData(frames=data) - for idx, frame in video_data.frame_generator(): - assert idx in data - assert np.all(frame == data[idx]) - - -def test_file_path(): - path = 'tests/integration/media/cat.mp4' - raster_data = VideoData(file_path=path) - - with pytest.raises(ValueError): - raster_data[0] - - raster_data.load_frames() - raster_data[0] - - frame_indices = list(raster_data.frames.keys()) - # 29 frames - assert set(frame_indices) == set(list(range(28))) - - -def test_file_url(): - url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerMeltdowns.mp4" - raster_data = VideoData(url=url) - - with pytest.raises(ValueError): - raster_data[0] - - raster_data.load_frames() - raster_data[0] - - frame_indices = list(raster_data.frames.keys()) - # 362 frames - assert set(frame_indices) == set(list(range(361))) - - -def test_ref(): - external_id = "external_id" - uid = "uid" - data = { - x: (np.random.random((32, 32, 3)) * 255).astype(np.uint8) - for x in range(5) - } - metadata = [] - media_attributes = {} - data = VideoData(frames=data, - external_id=external_id, - uid=uid, - metadata=metadata, - media_attributes=media_attributes) - assert data.external_id == external_id - assert data.uid == uid - assert data.media_attributes == media_attributes - assert data.metadata == metadata - ----- -tests/data/annotation_types/data/test_text.py -import os - -import pytest - -from labelbox.data.annotation_types import TextData -from labelbox import pydantic_compat - - -def test_validate_schema(): - with pytest.raises(pydantic_compat.ValidationError): - data = TextData() - - -def test_text(): - text = "hello world" - metadata = [] - media_attributes = {} - text_data = TextData(text=text, - metadata=metadata, - media_attributes=media_attributes) - assert text_data.text == text - - -def test_url(): - url = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample3.txt" - text_data = TextData(url=url) - text = text_data.value - assert len(text) == 3541 - - -def test_file(tmpdir): - content = "foo bar baz" - file = "hello.txt" - dir = tmpdir.mkdir('data') - dir.join(file).write(content) - text_data = TextData(file_path=os.path.join(dir.strpath, file)) - assert len(text_data.value) == len(content) - - -def test_ref(): - external_id = "external_id" - uid = "uid" - metadata = [] - media_attributes = {} - data = TextData(text="hello world", - external_id=external_id, - uid=uid, - metadata=metadata, - media_attributes=media_attributes) - assert data.external_id == external_id - assert data.uid == uid - assert data.media_attributes == media_attributes - assert data.metadata == metadata - ----- -docs/README.md -# Labelbox Python SDK API Documentation - -The Labelbox Python API documentation is generated from source code comments -using Sphinx (https://www.sphinx-doc.org/). - -## Preparing the Sphinx environment - -To generate the documentation install Sphinx and Sphinxcontrib-Napoleon. 
The
-easiest way to do this is with a Python virtual env and pip:
-
-```
-# create a virtual environment
-python3 -m venv labelbox_docs_venv
-
-# activate the venv
-source ./labelbox_docs_venv/bin/activate
-
-# upgrade venv pip and setuptools
-pip install --upgrade pip setuptools
-
-# install Sphinx and necessary contrib from requirements
-pip install -r labelbox_root/docs/requirements.txt
-
-# install Labelbox dependencies
-pip install -r labelbox_root/requirements.txt
-```
-
-There are other ways to prepare the environment, but we highly recommend
-using a Python virtual environment.
-
-## Generating Labelbox SDK API documentation
-
-With the Sphinx environment prepared, enter the docs folder:
-
-```
-cd labelbox_root/docs/
-```
-
-Run the make build tool, instructing it to build docs as HTML:
-```
-make html
-```
-
-----
-docs/source/conf.py
-# Configuration file for the Sphinx documentation builder.
-#
-# This file only contains a selection of the most common options. For a full
-# list see the documentation:
-# https://www.sphinx-doc.org/en/master/usage/configuration.html
-
-# -- Path setup --------------------------------------------------------------
-
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-#
-import os
-import sys
-
-sys.path.insert(0, os.path.abspath('../..'))
-
-# -- Project information -----------------------------------------------------
-
-project = 'Python SDK reference'
-copyright = '2021, Labelbox'
-author = 'Labelbox'
-
-release = '3.65.0'
-
-# -- General configuration ---------------------------------------------------
-
-# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
-# ones.
-extensions = [
-    'sphinx.ext.autodoc', 'sphinx.ext.viewcode', 'sphinx.ext.napoleon'
-]
-
-# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-# This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
-
-# -- Options for HTML output -------------------------------------------------
-
-# The theme to use for HTML and HTML Help pages. See the documentation for
-# a list of builtin themes.
-#
-html_theme = 'sphinx_rtd_theme'
-
-# Add any paths that contain custom static files (such as style sheets) here,
-# relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static'] - -# Prevent the sidebar from collapsing -html_js_files = ['js/prevent_collapse.js'] -html_theme_options = { - "collapse_navigation": False, -} - ----- -examples/format_notebooks.py -import glob -import json -from copy import deepcopy - -from yapf.yapflib.yapf_api import FormatCode - -BANNER_CELL = { - "cell_type": - "markdown", - "id": - "db768cda", - "metadata": {}, - "source": [ - "\n", - " \n", - "" - ] -} - -LINK_CELL = { - "cell_type": - "markdown", - "id": - "cb5611d0", - "metadata": {}, - "source": [ - "\n", "\n", - "\n", "\n", "\n", - "\n", - "" - ] -} - -COLAB_TEMPLATE = "https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/{filename}" -GITHUB_TEMPLATE = "https://github.com/Labelbox/labelbox-python/tree/master/examples/{filename}" - - -def format_cell(source): - for line in source.split('\n'): - if line.strip().startswith(('!', '%')): - return source - return FormatCode(source, style_config="google")[0] - - -def add_headers(file_name): - with open(file_name, 'r') as file: - data = json.load(file) - - colab_path = COLAB_TEMPLATE.format(filename=file_name) - github_path = GITHUB_TEMPLATE.format(filename=file_name) - - link_cell = deepcopy(LINK_CELL) - - link_cell['source'][1] = link_cell['source'][1].format(colab=colab_path) - link_cell['source'][6] = link_cell['source'][6].format(github=github_path) - - data['cells'] = [BANNER_CELL, link_cell] + data['cells'] - - with open(file_name, 'w') as file: - file.write(json.dumps(data, indent=4)) - - print("Formatted", file_name) - - -def format_file(file_name): - with open(file_name, 'r') as file: - data = json.load(file) - - idx = 1 - for cell in data['cells']: - if cell['cell_type'] == 'code': - cell['execution_count'] = idx - if isinstance(cell['source'], list): - cell['source'] = ''.join(cell['source']) - cell['source'] = format_cell(cell['source']) - idx += 1 - if cell['source'].endswith('\n'): - cell['source'] = cell['source'][:-1] - - with open(file_name, 'w') as file: - file.write(json.dumps(data, indent=4)) - print("Formatted", file_name) - - -if __name__ == '__main__': - for file in glob.glob("*/*.ipynb"): - format_file(file) - ----- -examples/test_notebooks.py -""" -Runs example notebooks to ensure that they are not throwing an error. 
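-
-Notebooks listed in SKIP_LIST below are excluded, and the whole module is
-currently skipped in CI (see the pytest.mark.skip reason); it can still be
-run locally after updating notebooks.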
-""" - -import pathlib -import pytest - -import nbformat -from nbconvert.preprocessors import ExecutePreprocessor - -examples_path = pathlib.Path(__file__).parent -notebook_paths = examples_path.glob('**/*.ipynb') -filtered_notebook_paths = [ - path for path in notebook_paths if '.ipynb_checkpoints' not in str(path) -] -relative_notebook_paths = [ - str(p.relative_to(examples_path)) for p in filtered_notebook_paths -] - - -def run_notebook(filename): - with open(filename) as ff: - nb_in = nbformat.read(ff, nbformat.NO_CONVERT) - - ep = ExecutePreprocessor(timeout=1200, kernel_name='python3') - - ep.preprocess(nb_in) - - -SKIP_LIST = [ - 'extras/classification-confusion-matrix.ipynb', - 'label_export/images.ipynb', - 'label_export/text.ipynb', - 'label_export/video.ipynb', - 'annotation_types/converters.ipynb', - 'integrations/detectron2/coco_panoptic.ipynb', - 'integrations/tlt/detectnet_v2_bounding_box.ipynb', - 'basics/datasets.ipynb', - 'basics/data_rows.ipynb', - 'basics/labels.ipynb', - 'basics/data_row_metadata.ipynb', - 'model_diagnostics/custom_metrics_basics.ipynb', - 'basics/user_management.ipynb', - 'integrations/tlt/labelbox_upload.ipynb', - 'model_diagnostics/custom_metrics_demo.ipynb', - 'model_diagnostics/model_diagnostics_demo.ipynb', - 'integrations/databricks/', - 'integrations/detectron2/coco_object.ipynb', - 'project_configuration/webhooks.ipynb', - 'basics/projects.ipynb', - 'model_diagnostics/model_diagnostics_guide.ipynb', -] - - -def skip_notebook(notebook_path): - for skip_path in SKIP_LIST: - if notebook_path.startswith(skip_path): - return True - return False - - -run_notebook_paths = [ - path for path in relative_notebook_paths if not skip_notebook(path) -] - - -@pytest.mark.skip( - 'Need some more work to run reliably, e.g. figuring out how to deal with ' - 'max number of models per org, therefore skipping from CI. 
However, this ' - 'test can be run locally after updating notebooks to ensure notebooks ' - 'are working.') -@pytest.mark.parametrize("notebook_path", run_notebook_paths) -def test_notebooks_run_without_errors(notebook_path): - run_notebook(examples_path / notebook_path) - ----- -examples/README.md -## Labelbox SDK Examples - -- Learn how to use the SDK by following along -- Run in google colab, view the notebooks on github, or clone the repo and run locally - ---- - -## [Basics](basics) - -| Notebook | Github | Google Colab | -| ----------------- | ---------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Fundamentals | [Github](basics/basics.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/basics/basics.ipynb) | -| Batches | [Github](basics/batches.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/basics/batches.ipynb) | -| Data Rows | [Github](basics/data_rows.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/basics/data_rows.ipynb) | -| Data Row Metadata | [Github](basics/data_row_metadata.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/basics/data_row_metadata.ipynb) | -| Datasets | [Github](basics/datasets.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/basics/datasets.ipynb) | -| Export data | [Github](exports/export_data.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/exports/export_data.ipynb) | -| Ontologies | [Github](basics/ontologies.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/basics/ontologies.ipynb) | -| Projects | [Github](basics/projects.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/basics/projects.ipynb) | -| User Management | [Github](basics/user_management.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/basics/user_management.ipynb) | - ---- - -## [Model Training](https://docs.labelbox.com/docs/integration-with-model-training-service) - -Train a model using data annotated on Labelbox - -| Notebook | Github | Google Colab | -| ------------------------------- | ----------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Object Detection (Detectron2) | 
[Github](integrations/detectron2/coco_object.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/integrations/detectron2/coco_object.ipynb) | -| Panoptic Detection (Detectron2) | [Github](integrations/detectron2/coco_panoptic.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/integrations/detectron2/coco_panoptic.ipynb) | - ---- - -## [Annotation Import (Ground Truth & MAL)](annotation_import) - -| Notebook | Github | Google Colab | Learn more | -| ------------------------------------- | ------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- | -| Image Annotation Import | [Github](annotation_import/image.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/annotation_import/image.ipynb) | [Docs](https://docs.labelbox.com/reference/import-image-annotations) | -| Text Annotation Import | [Github](annotation_import/text.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/annotation_import/text.ipynb) | [Docs](https://docs.labelbox.com/reference/import-text-annotations) | -| Tiled Imagery Annotation Import | [Github](annotation_import/tiled.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/annotation_import/tiled.ipynb) | [Docs](https://docs.labelbox.com/reference/import-geospatial-annotations) | -| Video Annotation Import | [Github](annotation_import/video.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/annotation_import/video.ipynb) | [Docs](https://docs.labelbox.com/reference/import-video-annotations) | -| PDF Annotation Import | [Github](annotation_import/pdf.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/annotation_import/pdf.ipynb) | [Docs](https://docs.labelbox.com/reference/import-document-annotations) | -| Audio Annotation Import | [Github](annotation_import/audio.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/annotation_import/audio.ipynb) | [Docs](https://docs.labelbox.com/reference/import-audio-annotations) | -| HTML Annotation Import | [Github](annotation_import/html.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/annotation_import/html.ipynb) | [Docs](https://docs.labelbox.com/reference/import-html-annotations) | -| DICOM Annotation Import | [Github](annotation_import/dicom.ipynb) | [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/annotation_import/dicom.ipynb) | [Docs](https://docs.labelbox.com/reference/import-dicom-annotations) | -| Conversational Text Annotation Import | [Github](annotation_import/conversational.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/annotation_import/conversational.ipynb) | [Docs](https://docs.labelbox.com/reference/import-conversational-text-annotations) | - ---- - -## [Project Configuration](project_configuration) - -| Notebook | Github | Google Colab | -| ---------------- | ------------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Project Setup | [Github](project_configuration/project_setup.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/project_configuration/project_setup.ipynb) | -| Queue Management | [Github](project_configuration/queue_management.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/project_configuration/queue_management.ipynb) | -| Webhooks | [Github](project_configuration/webhooks.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/project_configuration/webhooks.ipynb) | - -## [Prediction Upload to a Model Run](prediction_upload) - -| Notebook | Github | Google Colab | Learn more | -| ------------------------------------- | ------------------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------- | -| Image Prediction upload | [Github](prediction_upload/image_predictions.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/prediction_upload/image_predictions.ipynb) | [Docs](https://docs.labelbox.com/reference/upload-image-predictions) | -| Text Prediction upload | [Github](prediction_upload/text_predictions.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/prediction_upload/text_predictions.ipynb) | [Docs](https://docs.labelbox.com/reference/upload-text-predictions) | -| Video Prediction upload | [Github](prediction_upload/video_predictions.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/prediction_upload/video_predictions.ipynb) | [Docs](https://docs.labelbox.com/reference/upload-video-predictions) | -| HTML Prediction upload | [Github](prediction_upload/html_predictions.ipynb) | 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/prediction_upload/html_predictions.ipynb) | [Docs](https://docs.labelbox.com/reference/upload-html-predictions) |
-| PDF Prediction upload | [Github](prediction_upload/pdf_predictions.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/prediction_upload/pdf_predictions.ipynb) |
-| Geospatial Prediction upload | [Github](prediction_upload/geospatial_predictions.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/prediction_upload/geospatial_predictions.ipynb) | [Docs](https://docs.labelbox.com/reference/upload-geospatial-predictions) |
-| Conversational Text Prediction upload | [Github](prediction_upload/conversational_predictions.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/prediction_upload/conversational_predictions.ipynb) |
-
----
-
-## [Extras](extras)
-
-| Notebook | Github | Google Colab |
-| ------------------------------- | ------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| Classification Confusion Matrix | [Github](extras/classification-confusion-matrix.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/extras/classification-confusion-matrix.ipynb) |
-
-----
-examples/project_configuration/README.md
-- This section describes advanced usage of project configuration.
-- Make sure you are comfortable with the basics/project section first.
-
-----
-examples/integrations/databricks/readme.md
-# Databricks + Labelbox
-
-##### Use the Labelbox Connector to easily work with unstructured data in Databricks
-
---------
-
-
-#### [Tutorial Notebook](labelbox_databricks_example.ipynb)
-* Load DataFrame of unstructured data (URLs to video, images, or text)
-* Create the dataset in Labelbox
-* Annotate in Labelbox
-* Load annotations into Databricks for easy querying and model training
-
-#### [API Key Notebook](api_key_db_template.ipynb)
-* This is a helper notebook for users without access to the Databricks Secrets API
-* Allows you to store your Labelbox API key outside of your main notebook, for better security
-* We do recommend you use the Secrets API whenever possible
-
-More information about the Connector is available on [PyPI](https://pypi.org/project/labelspark/)
-
-[Connector Source Code](https://github.com/Labelbox/labelspark/)
-
-----
-examples/integrations/databricks/labelbox_databricks_example.py
-# Databricks notebook source
-# MAGIC %md
-# MAGIC # Labelbox Connector for Databricks Tutorial Notebook
-
-# COMMAND ----------
-
-# MAGIC %md
-# MAGIC #### Pre-requisites
-# MAGIC 1. This tutorial notebook requires a Labelbox API Key. Please log in to your [Labelbox account](https://app.labelbox.com) and generate an [API Key](https://app.labelbox.com/account/api-keys)
-# MAGIC 2.
A few cells below will install the Labelbox SDK and Connector Library. This install is notebook-scoped and will not affect the rest of your cluster. -# MAGIC 3. Please make sure you are running at least the latest LTS version of Databricks. -# MAGIC -# MAGIC #### Notebook Preview -# MAGIC This notebook will guide you through these steps: -# MAGIC 1. Connect to Labelbox via the SDK -# MAGIC 2. Create a labeling dataset from a table of unstructured data in Databricks -# MAGIC 3. Programmatically set up an ontology and labeling project in Labelbox -# MAGIC 4. Load Bronze and Silver annotation tables from an example labeled project -# MAGIC 5. Additional cells describe how to handle video annotations and use Labelbox Diagnostics and Catalog -# MAGIC -# MAGIC Additional documentation links are provided at the end of the notebook. - -# COMMAND ---------- - -# MAGIC %md -# MAGIC Thanks for trying out the Databricks and Labelbox Connector! You or someone from your organization signed up for a Labelbox trial through Databricks Partner Connect. This notebook was loaded into your Shared directory to help illustrate how Labelbox and Databricks can be used together to power unstructured data workflows. -# MAGIC -# MAGIC Labelbox can be used to rapidly annotate a variety of unstructured data from your Data Lake ([images](https://labelbox.com/product/image), [video](https://labelbox.com/product/video), [text](https://labelbox.com/product/text), and [geospatial tiled imagery](https://docs.labelbox.com/docs/tiled-imagery-editor)) and the Labelbox Connector for Databricks makes it easy to bring the annotations back into your Lakehouse environment for AI/ML and analytical workflows. -# MAGIC -# MAGIC If you would like to watch a video of the workflow, check out our [Data & AI Summit Demo](https://databricks.com/session_na21/productionizing-unstructured-data-for-ai-and-analytics). -# MAGIC -# MAGIC -# MAGIC example-workflow -# MAGIC -# MAGIC
Questions or comments? Reach out to us at [ecosystem+databricks@labelbox.com](mailto:ecosystem+databricks@labelbox.com) - -# COMMAND ---------- - -# DBTITLE 1,Install Labelbox Library & Labelbox Connector for Databricks -# MAGIC %pip install labelbox labelspark - -# COMMAND ---------- - -#This will import Koalas or Pandas-on-Spark based on your DBR version. -from pyspark import SparkContext -from packaging import version - -sc = SparkContext.getOrCreate() -if version.parse(sc.version) < version.parse("3.2.0"): - import databricks.koalas as pd - needs_koalas = True -else: - import pyspark.pandas as pd - needs_koalas = False - -# COMMAND ---------- - -# MAGIC %md -# MAGIC ## Configure the SDK -# MAGIC -# MAGIC Now that Labelbox and the Databricks libraries have been installed, you will need to configure the SDK. You will need an API key that you can create through the app [here](https://app.labelbox.com/account/api-keys). You can also store the key using Databricks Secrets API. The SDK will attempt to use the env var `LABELBOX_API_KEY` - -# COMMAND ---------- - -import labelbox as lb -import labelspark - -API_KEY = "" - -if not (API_KEY): - raise ValueError("Go to Labelbox to get an API key") - -client = lb.Client(API_KEY) - -# COMMAND ---------- - -# MAGIC %md -# MAGIC ## Create seed data -# MAGIC -# MAGIC Next we'll load a demo dataset into a Spark table so you can see how to easily load assets into Labelbox via URLs with the Labelbox Connector for Databricks. -# MAGIC -# MAGIC Also, Labelbox has native support for AWS, Azure, and GCP cloud storage. You can connect Labelbox to your storage via [Delegated Access](https://docs.labelbox.com/docs/iam-delegated-access) and easily load those assets for annotation. For more information, you can watch this [video](https://youtu.be/wlWo6EmPDV4). -# MAGIC -# MAGIC You can also add data to Labelbox [using the Labelbox SDK directly](https://docs.labelbox.com/docs/datasets-datarows). We recommend using the SDK if you have complicated dataset creation requirements (e.g. including metadata with your dataset) which aren't handled by the Labelbox Connector for Databricks. 
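-# MAGIC
-# MAGIC As a quick illustration, the next cell sketches what a direct SDK upload with metadata can look like. It is a minimal sketch, not part of this tutorial's flow: it assumes the reserved "tag" metadata field is available in your metadata ontology, and the dataset name, external ID, and tag value are placeholders.
-
-# COMMAND ----------
-
-# Minimal sketch: direct SDK upload of one data row with a metadata field.
-# Values below are placeholders; skip this cell if you don't need metadata.
-from labelbox.schema.data_row_metadata import DataRowMetadataField
-
-mdo = client.get_data_row_metadata_ontology()
-tag_schema = mdo.reserved_by_name["tag"]  # "tag" is a reserved string field
-
-sdk_dataset = client.create_dataset(name="SDK Metadata Example")  # placeholder name
-task = sdk_dataset.create_data_rows([{
-    "row_data":
-        "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000247422.jpg",
-    "external_id": "sample1.jpg",
-    "metadata_fields": [
-        DataRowMetadataField(schema_id=tag_schema.uid, value="databricks-demo"),
-    ],
-}])
-task.wait_till_done()  # create_data_rows is asynchronous and returns a Task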
-
-# COMMAND ----------
-
-sample_dataset_dict = {
-    "external_id": [
-        "sample1.jpg", "sample2.jpg", "sample3.jpg", "sample4.jpg",
-        "sample5.jpg", "sample6.jpg", "sample7.jpg", "sample8.jpg",
-        "sample9.jpg", "sample10.jpg"
-    ],
-    "row_data": [
-        "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000247422.jpg",
-        "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000484849.jpg",
-        "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000215782.jpg",
-        "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_val2014_000000312024.jpg",
-        "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000486139.jpg",
-        "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000302713.jpg",
-        "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000523272.jpg",
-        "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000094514.jpg",
-        "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_val2014_000000050578.jpg",
-        "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000073727.jpg"
-    ]
-}
-
-df = pd.DataFrame.from_dict(sample_dataset_dict).to_spark(
-)  # produces our demo Spark table of data rows for Labelbox
-
-# COMMAND ----------
-
-# register the demo image URLs as a Spark table
-SAMPLE_TABLE = "sample_unstructured_data"
-
-tblList = spark.catalog.listTables()
-
-if not any([table.name == SAMPLE_TABLE for table in tblList]):
-    df.createOrReplaceTempView(SAMPLE_TABLE)
-    print(f"Registered table: {SAMPLE_TABLE}")
-
-# COMMAND ----------
-
-# MAGIC %md
-# MAGIC You should now have a temporary table "sample_unstructured_data" which includes the file names and URLs for some demo images. We're going to bring this table into Labelbox using the Labelbox Connector for Databricks!
-
-# COMMAND ----------
-
-display(sqlContext.sql(f"select * from {SAMPLE_TABLE} LIMIT 5"))
-
-# COMMAND ----------
-
-# MAGIC %md
-# MAGIC ## Create a Labeling Project
-# MAGIC
-# MAGIC Projects are where teams create labels. A project requires a dataset of assets to be labeled and an ontology to configure the labeling interface.
-# MAGIC
-# MAGIC ### Step 1: Create a dataset
-# MAGIC
-# MAGIC The [Labelbox Connector for Databricks](https://pypi.org/project/labelspark/) expects a Spark table with two columns: the first column "external_id" and the second column "row_data"
-# MAGIC
-# MAGIC external_id is a filename, like "birds.jpg" or "my_video.mp4"
-# MAGIC
-# MAGIC row_data is the URL path to the file. Labelbox renders assets locally on your users' machines when they label, so your labeler will need permission to access that asset.
-# MAGIC
-# MAGIC Example:
-# MAGIC
-# MAGIC | external_id | row_data |
-# MAGIC |-------------|--------------------------------------|
-# MAGIC | image1.jpg | https://url_to_your_asset/image1.jpg |
-# MAGIC | image2.jpg | https://url_to_your_asset/image2.jpg |
-# MAGIC | image3.jpg | https://url_to_your_asset/image3.jpg |
-
-# COMMAND ----------
-
-unstructured_data = spark.table(SAMPLE_TABLE)
-
-demo_dataset = labelspark.create_dataset(client,
-                                         unstructured_data,
-                                         name="Databricks Demo Dataset")
-
-# COMMAND ----------
-
-print("Open the dataset in the App")
-print(f"https://app.labelbox.com/data/{demo_dataset.uid}")
-
-# COMMAND ----------
-
-# MAGIC %md
-# MAGIC ### Step 2: Create a project
-# MAGIC
-# MAGIC You can use the Labelbox SDK to build your ontology (we'll do that next). You can also set your project up entirely through our website at app.labelbox.com.
-# MAGIC
-# MAGIC Check out our [ontology creation documentation](https://docs.labelbox.com/docs/configure-ontology).
-
-# COMMAND ----------
-
-# Create a new project
-project_demo = client.create_project(name="Labelbox and Databricks Example",
-                                     media_type=lb.MediaType.Image)
-project_demo.datasets.connect(demo_dataset)  # add the dataset to the queue
-
-ontology = lb.OntologyBuilder()
-
-tools = [
-    lb.Tool(tool=lb.Tool.Type.BBOX, name="Car"),
-    lb.Tool(tool=lb.Tool.Type.BBOX, name="Flower"),
-    lb.Tool(tool=lb.Tool.Type.BBOX, name="Fruit"),
-    lb.Tool(tool=lb.Tool.Type.BBOX, name="Plant"),
-    lb.Tool(tool=lb.Tool.Type.SEGMENTATION, name="Bird"),
-    lb.Tool(tool=lb.Tool.Type.SEGMENTATION, name="Person"),
-    lb.Tool(tool=lb.Tool.Type.SEGMENTATION, name="Dog"),
-    lb.Tool(tool=lb.Tool.Type.SEGMENTATION, name="Gemstone"),
-]
-for tool in tools:
-    ontology.add_tool(tool)
-
-conditions = ["clear", "overcast", "rain", "other"]
-
-weather_classification = lb.Classification(
-    class_type=lb.Classification.Type.RADIO,
-    instructions="what is the weather?",
-    options=[lb.Option(value=c) for c in conditions])
-ontology.add_classification(weather_classification)
-
-# Set up the editor
-for editor in client.get_labeling_frontends():
-    if editor.name == 'Editor':
-        project_demo.setup(editor, ontology.asdict())
-
-print("Project Setup is complete.")
-
-# COMMAND ----------
-
-# MAGIC %md
-# MAGIC ### Step 3: Go label data
-
-# COMMAND ----------
-
-print("Open the project to start labeling")
-print(f"https://app.labelbox.com/projects/{project_demo.uid}/overview")
-
-# COMMAND ----------
-
-raise ValueError("Go label some data before continuing")
-
-# COMMAND ----------
-
-# MAGIC %md
-# MAGIC ## Exporting labels/annotations
-# MAGIC
-# MAGIC After creating labels in Labelbox you can export them to use in Databricks for model training and analysis.
-
-# COMMAND ----------
-
-LABEL_TABLE = "exported_labels"
-
-# COMMAND ----------
-
-labels_table = labelspark.get_annotations(client, project_demo.uid, spark, sc)
-labels_table.createOrReplaceTempView(LABEL_TABLE)
-display(labels_table)
-
-# COMMAND ----------
-
-# MAGIC %md
-# MAGIC ## Other features of Labelbox
-# MAGIC
-# MAGIC [Model Assisted Labeling](https://docs.labelbox.com/docs/model-assisted-labeling)
-# MAGIC
Once you train a model on your initial set of unstructured data, you can plug that model into Labelbox to support a Model Assisted Labeling workflow. Review the outputs of your model, make corrections, and retrain with ease! You can reduce future labeling costs by >50% by leveraging model assisted labeling. -# MAGIC -# MAGIC MAL -# MAGIC -# MAGIC [Catalog](https://docs.labelbox.com/docs/catalog) -# MAGIC
Once you've created datasets and annotations in Labelbox, you can easily browse your datasets and curate new ones in Catalog. Use your model embeddings to find images by similarity search. -# MAGIC -# MAGIC Catalog -# MAGIC -# MAGIC [Model Diagnostics](https://labelbox.com/product/model-diagnostics) -# MAGIC
Labelbox complements your MLFlow experiment tracking with the ability to easily visualize experiment predictions at scale. Model Diagnostics helps you quickly identify areas where your model is weak so you can collect the right data and refine the next model iteration. -# MAGIC -# MAGIC Diagnostics - -# COMMAND ---------- - -# DBTITLE 1,More Info -# MAGIC %md -# MAGIC While using the Labelbox Connector for Databricks, you will likely use the Labelbox SDK (e.g. for programmatic ontology creation). These resources will help familiarize you with the Labelbox Python SDK: -# MAGIC * [Visit our docs](https://labelbox.com/docs/python-api) to learn how the SDK works -# MAGIC * Checkout our [notebook examples](https://github.com/Labelbox/labelspark/tree/master/notebooks) to follow along with interactive tutorials -# MAGIC * view our [API reference](https://labelbox.com/docs/python-api/api-reference). -# MAGIC -# MAGIC Questions or comments? Reach out to us at [ecosystem+databricks@labelbox.com](mailto:ecosystem+databricks@labelbox.com) - -# COMMAND ---------- - -# MAGIC %md -# MAGIC Copyright Labelbox, Inc. 2022. The source in this notebook is provided subject to the [Labelbox Terms of Service](https://docs.labelbox.com/page/terms-of-service). All included or referenced third party libraries are subject to the licenses set forth below. -# MAGIC -# MAGIC |Library Name|Library license | Library License URL | Library Source URL | -# MAGIC |---|---|---|---| -# MAGIC |Labelbox Python SDK|Apache-2.0 License |https://github.com/Labelbox/labelbox-python/blob/develop/LICENSE|https://github.com/Labelbox/labelbox-python -# MAGIC |Labelbox Connector for Databricks|Apache-2.0 License |https://github.com/Labelbox/labelspark/blob/master/LICENSE|https://github.com/Labelbox/labelspark -# MAGIC |Python|Python Software Foundation (PSF) |https://github.com/python/cpython/blob/master/LICENSE|https://github.com/python/cpython| -# MAGIC |Apache Spark|Apache-2.0 License |https://github.com/apache/spark/blob/master/LICENSE|https://github.com/apache/spark| - ----- -examples/integrations/tlt/README.md -# NVIDIA + Labelbox - -##### Turn any Labelbox bounding box project into a deployed service by following these tutorials - --------- - - -#### labelbox_upload.ipynb -* Download images and prelabels -* Setup a labelbox project -* Upload prelabels to labelbox using MAL -* Clean up the data in labelbox - -#### detectnet_v2_bounding_box.ipynb -* Plug in training data from previous step (or bring your own labelbox project) -* Train a model using TLT. Compare with a non-pretrained model -* Prune the model for more efficient deployment -* Convert the model to a TRT engine -* Deploy the model using Triton Inference Server - - ----- -examples/integrations/detectron2/README.md -![Logo](images/detectron-logo.png) - -Detectron2 is Facebook AI Research's next generation library that provides state-of-the-art detection and segmentation -algorithms. Check out the official repository [here](https://github.com/facebookresearch/detectron2) - - -
- -
- -# Getting Started - -The Labelbox team has created two notebooks to help you train your own Detectron2 model with data you have annotated on -Labelbox. - -| Notebook | Github | Google Colab | -| --------------------------- | --------------------------------- | ------------ | -| Object Detection | [Github](coco_object.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/integrations/detectron2/coco_panoptic.ipynb) | -| Panoptic Detection | [Github](coco_panoptic.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/integrations/detectron2/coco_panoptic.ipynb) ------- - - ----- -examples/integrations/detectron2/coco_utils.py -from concurrent.futures import ThreadPoolExecutor, as_completed -import functools -import random -import json -import os - -import numpy as np -from PIL import Image -from tqdm import tqdm -import cv2 -from detectron2.utils.visualizer import Visualizer -from panopticapi.utils import rgb2id - - -def get_annotations(images, all_annotations): - image_lookup = {image['id'] for image in images} - return [ - annot for annot in all_annotations if annot['image_id'] in image_lookup - ] - - -def partition_indices(total_n, splits): - - if splits is None: - raise ValueError("") - - if sum(splits) != 1.: - raise ValueError(f"Found {sum(splits)}. Expected 1.") - - splits = np.cumsum(splits) - for idx in range(len((splits))): - start = 0 if idx == 0 else int(total_n * splits[idx - 1]) - end = int(splits[idx] * total_n) - yield start, end - - -def partition_coco(coco_instance_data, coco_panoptic_data=None, splits=None): - images = coco_instance_data['images'] - n_classes = len( - {category['id'] for category in coco_instance_data['categories']}) - random.shuffle(images) - partitions = [] - for start, end in partition_indices(len(images), splits): - partition = { - 'instance': - dict(categories=coco_instance_data['categories'], - images=images[start:end], - annotations=get_annotations( - images[start:end], coco_instance_data['annotations'])) - } - if coco_panoptic_data is not None: - partition['panoptic'] = dict( - categories=coco_panoptic_data['categories'], - images=images[start:end], - annotations=get_annotations(images[start:end], - coco_panoptic_data['annotations'])) - partitions.append(partition) - return partitions - - -def visualize_object_inferences(metadata_catalog, - coco_examples, - predictor, - scale=1.0, - max_images=5, - resize_dims=(768, 512)): - images = [] - for idx, example in enumerate(coco_examples): - if idx > max_images: - break - im = cv2.imread(example['file_name']) - outputs = predictor(im) - v = Visualizer(im[:, :, ::-1], metadata_catalog, scale=scale) - out = v.draw_instance_predictions(outputs["instances"].to("cpu")) - images.append(cv2.resize(out.get_image()[:, :, ::-1], resize_dims)) - return Image.fromarray(np.vstack(images)) - - -def visualize_coco_examples(metadata_catalog, - object_examples, - panoptic_examples=None, - scale=1.0, - max_images=5, - resize_dims=(768, 512)): - if panoptic_examples is not None: - lookup = {d['file_name']: d for d in panoptic_examples} - - images = [] - for idx, example in enumerate(object_examples): - if idx > max_images: - break - im = cv2.imread(example['file_name']) - v = Visualizer(im[:, :, ::-1], metadata_catalog, scale=scale) - out = v.draw_dataset_dict(example) - if 
panoptic_examples is not None: - example_panoptic = lookup.get(example['file_name']) - if example_panoptic is not None: - out = v.draw_dataset_dict(example_panoptic) - images.append(cv2.resize(out.get_image(), resize_dims)) - return Image.fromarray(np.vstack(images)) - - -def _process_panoptic_to_semantic(input_panoptic, output_semantic, segments, - id_map): - panoptic = np.asarray(Image.open(input_panoptic), dtype=np.uint32) - panoptic = rgb2id(panoptic) - - output = np.zeros_like(panoptic, dtype=np.uint8) - for seg in segments: - cat_id = seg["category_id"] - new_cat_id = id_map[cat_id] - output[panoptic == seg["id"]] = new_cat_id - Image.fromarray(output).save(output_semantic) - - -def separate_coco_semantic_from_panoptic(panoptic_json, panoptic_root, - sem_seg_root, categories): - """ - Create semantic segmentation annotations from panoptic segmentation - annotations, to be used by PanopticFPN. - - It maps all thing categories to class 0, and maps all unlabeled pixels to class 255. - It maps all stuff categories to contiguous ids starting from 1. - - Args: - panoptic_json (str): path to the panoptic json file, in COCO's format. - panoptic_root (str): a directory with panoptic annotation files, in COCO's format. - sem_seg_root (str): a directory to output semantic annotation files - categories (list[dict]): category metadata. Each dict needs to have: - "id": corresponds to the "category_id" in the json annotations - "isthing": 0 or 1 - """ - os.makedirs(sem_seg_root, exist_ok=True) - - stuff_ids = [k["id"] for k in categories if k["isthing"] == 0] - thing_ids = [k["id"] for k in categories if k["isthing"] == 1] - id_map = {} # map from category id to id in the output semantic annotation - assert len(stuff_ids) <= 254 - for i, stuff_id in enumerate(stuff_ids): - id_map[stuff_id] = i + 1 - for thing_id in thing_ids: - id_map[thing_id] = 0 - id_map[0] = 255 - - with open(panoptic_json) as f: - obj = json.load(f) - - def iter_annotations(): - for anno in obj["annotations"]: - file_name = anno["file_name"] - segments = anno["segments_info"] - input = os.path.join(panoptic_root, file_name) - output = os.path.join(sem_seg_root, file_name) - yield input, output, segments - - fn = functools.partial(_process_panoptic_to_semantic, id_map=id_map) - futures = [] - with ThreadPoolExecutor(max_workers=12) as executor: - for args in iter_annotations(): - futures.append(executor.submit(fn, *args)) - for _ in tqdm(as_completed(futures)): - _.result() - ----- -examples/scripts/upload_documentation.py -import glob -import json -import os - -import click -import nbformat -import requests -from nbconvert import MarkdownExporter - -README_AUTH = os.getenv('README_AUTH') -README_ENDPOINT = "https://dash.readme.com/api/v1/docs" -README_DOC_ENDPOINT = "https://dash.readme.com/api/v1/docs/" -CATEGORY_ID = '61fb645198ad91004246bd5f' -CATEGORY_SLUG = 'tutorials' - - -def upload_doc(path, section): - title = path.split('/')[-1].replace(".ipynb", - '').capitalize().replace('_', ' ') - - with open(path) as fb: - nb = nbformat.reads(json.dumps(json.load(fb)), as_version=4) - - nb.cells = [nb.cells[1]] + nb.cells[3:] - nb.cells[0]['source'] += '\n' - exporter = MarkdownExporter() - - body, resources = exporter.from_notebook_node(nb) - - payload = { - "hidden": - True, - "title": - f'{section["slug"]}-' + title.replace(' ', '-').replace("(", ''), - "category": - CATEGORY_ID, - "parentDoc": - section['id'], - "body": - body - } - - headers = { - "Accept": "application/json", - "Content-Type": "application/json", - 
"Authorization": README_AUTH - } - - response = requests.post(README_ENDPOINT, json=payload, headers=headers) - response.raise_for_status() - data = response.json() - change_name(data['slug'], title, headers, hidden=False) - - -def make_sections(sections): - for section in sections: - print(section) - payload = { - "hidden": True, - "order": section['order'], - "title": 'section-' + section['dir'] + '-notebooks', - "category": CATEGORY_ID - } - headers = { - "Accept": "application/json", - "Content-Type": "application/json", - "Authorization": README_AUTH - } - - response = requests.post(README_ENDPOINT, json=payload, headers=headers) - data = response.json() - - section['id'] = data['id'] - section['slug'] = data['slug'] - - change_name(data["slug"], section['title'], headers, False) - - return sections - - -def change_name(slug, title, headers, hidden=True): - resp = requests.put(f'{README_DOC_ENDPOINT}/{slug}', - json={ - "hidden": hidden, - "title": title, - "category": CATEGORY_ID - }, - headers=headers) - resp.raise_for_status() - - -def erase_category_docs(cat_slug): - headers = {"Accept": "application/json", "Authorization": README_AUTH} - - response = requests.request( - "GET", - f'https://dash.readme.com/api/v1/categories/{cat_slug}/docs', - headers=headers) - docs = response.json() - for doc in docs: - for child in doc["children"]: - resp = requests.delete(f'{README_DOC_ENDPOINT}/{child["slug"]}', - headers=headers) - resp = requests.delete(f'{README_DOC_ENDPOINT}/{doc["slug"]}', - headers=headers) - - -@click.command() -@click.option('--config-path') -# @click.option('--output-path') -def main(config_path): - # print(input_path) - erase_category_docs(CATEGORY_SLUG) - with open(config_path) as fb: - config = json.load(fb) - config = make_sections(config) - - for section in config: - print(section, '\n------') - for path in glob.glob(f'{section["dir"]}/**.ipynb'): - print('*', path) - upload_doc(path, section) - print('-------') - - -if __name__ == '__main__': - main() - ----- -labelbox/typing_imports.py -""" -This module imports types that differ across python versions, so other modules -don't have to worry about where they should be imported from. 
-""" - -import sys -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal ----- -labelbox/client.py -# type: ignore -import json -import logging -import mimetypes -import os -import random -import sys -import time -import urllib.parse -from collections import defaultdict -from datetime import datetime, timezone -from typing import Any, List, Dict, Union, Optional - -import requests -import requests.exceptions -from google.api_core import retry - -import labelbox.exceptions -from labelbox import __version__ as SDK_VERSION -from labelbox import utils -from labelbox.orm import query -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Entity -from labelbox.pagination import PaginatedCollection -from labelbox.schema import role -from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy -from labelbox.schema.data_row import DataRow -from labelbox.schema.catalog import Catalog -from labelbox.schema.data_row_metadata import DataRowMetadataOntology -from labelbox.schema.dataset import Dataset -from labelbox.schema.enums import CollectionJobStatus -from labelbox.schema.foundry.foundry_client import FoundryClient -from labelbox.schema.iam_integration import IAMIntegration -from labelbox.schema.identifiables import DataRowIds -from labelbox.schema.identifiables import GlobalKeys -from labelbox.schema.labeling_frontend import LabelingFrontend -from labelbox.schema.media_type import MediaType, get_media_type_validation_error -from labelbox.schema.model import Model -from labelbox.schema.model_run import ModelRun -from labelbox.schema.ontology import Ontology, DeleteFeatureFromOntologyResult -from labelbox.schema.ontology import Tool, Classification, FeatureSchema -from labelbox.schema.organization import Organization -from labelbox.schema.project import Project -from labelbox.schema.quality_mode import QualityMode, BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS, \ - BENCHMARK_AUTO_AUDIT_PERCENTAGE, CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS, CONSENSUS_AUTO_AUDIT_PERCENTAGE -from labelbox.schema.queue_mode import QueueMode -from labelbox.schema.role import Role -from labelbox.schema.send_to_annotate_params import SendToAnnotateFromCatalogParams, build_destination_task_queue_input, \ - build_predictions_input, build_annotations_input -from labelbox.schema.slice import CatalogSlice, ModelSlice -from labelbox.schema.task import Task -from labelbox.schema.user import User - -logger = logging.getLogger(__name__) - -_LABELBOX_API_KEY = "LABELBOX_API_KEY" - - -def python_version_info(): - version_info = sys.version_info - - return f"{version_info.major}.{version_info.minor}.{version_info.micro}-{version_info.releaselevel}" - - -class Client: - """ A Labelbox client. - - Contains info necessary for connecting to a Labelbox server (URL, - authentication key). Provides functions for querying and creating - top-level data objects (Projects, Datasets). - """ - - def __init__(self, - api_key=None, - endpoint='https://api.labelbox.com/graphql', - enable_experimental=False, - app_url="https://app.labelbox.com", - rest_endpoint="https://api.labelbox.com/api/v1"): - """ Creates and initializes a Labelbox Client. - - Logging is defaulted to level WARNING. To receive more verbose - output to console, update `logging.level` to the appropriate level. - - >>> logging.basicConfig(level = logging.INFO) - >>> client = Client("") - - Args: - api_key (str): API key. 
If None, the key is obtained from the "LABELBOX_API_KEY" environment variable. - endpoint (str): URL of the Labelbox server to connect to. - enable_experimental (bool): Indicates whether or not to use experimental features - app_url (str) : host url for all links to the web app - Raises: - labelbox.exceptions.AuthenticationError: If no `api_key` - is provided as an argument or via the environment - variable. - """ - if api_key is None: - if _LABELBOX_API_KEY not in os.environ: - raise labelbox.exceptions.AuthenticationError( - "Labelbox API key not provided") - api_key = os.environ[_LABELBOX_API_KEY] - self.api_key = api_key - - self.enable_experimental = enable_experimental - if enable_experimental: - logger.info("Experimental features have been enabled") - - logger.info("Initializing Labelbox client at '%s'", endpoint) - self.app_url = app_url - self.endpoint = endpoint - self.rest_endpoint = rest_endpoint - - self.headers = { - 'Accept': 'application/json', - 'Content-Type': 'application/json', - 'Authorization': 'Bearer %s' % api_key, - 'X-User-Agent': f"python-sdk {SDK_VERSION}", - 'X-Python-Version': f"{python_version_info()}", - } - self._data_row_metadata_ontology = None - - @retry.Retry(predicate=retry.if_exception_type( - labelbox.exceptions.InternalServerError, - labelbox.exceptions.TimeoutError)) - def execute(self, - query=None, - params=None, - data=None, - files=None, - timeout=60.0, - experimental=False, - error_log_key="message"): - """ Sends a request to the server for the execution of the - given query. - - Checks the response for errors and wraps errors - in appropriate `labelbox.exceptions.LabelboxError` subtypes. - - Args: - query (str): The query to execute. - params (dict): Query parameters referenced within the query. - data (str): json string containing the query to execute - files (dict): file arguments for request - timeout (float): Max allowed time for query execution, - in seconds. - Returns: - dict, parsed JSON response. - Raises: - labelbox.exceptions.AuthenticationError: If authentication - failed. - labelbox.exceptions.InvalidQueryError: If `query` is not - syntactically or semantically valid (checked server-side). - labelbox.exceptions.ApiLimitError: If the server API limit was - exceeded. See "How to import data" in the online documentation - to see API limits. - labelbox.exceptions.TimeoutError: If response was not received - in `timeout` seconds. - labelbox.exceptions.NetworkError: If an unknown error occurred - most likely due to connection issues. - labelbox.exceptions.LabelboxError: If an unknown error of any - kind occurred. - ValueError: If query and data are both None. - """ - logger.debug("Query: %s, params: %r, data %r", query, params, data) - - # Convert datetimes to UTC strings. 
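        # For example (illustrative value): an aware
        # datetime(2024, 3, 5, 12, 0, tzinfo=timezone.utc) is sent to the
        # server as the string "2024-03-05T12:00:00Z".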
- def convert_value(value): - if isinstance(value, datetime): - value = value.astimezone(timezone.utc) - value = value.strftime("%Y-%m-%dT%H:%M:%SZ") - return value - - if query is not None: - if params is not None: - params = { - key: convert_value(value) for key, value in params.items() - } - data = json.dumps({ - 'query': query, - 'variables': params - }).encode('utf-8') - elif data is None: - raise ValueError("query and data cannot both be none") - - endpoint = self.endpoint if not experimental else self.endpoint.replace( - "/graphql", "/_gql") - - try: - request = { - 'url': endpoint, - 'data': data, - 'headers': self.headers, - 'timeout': timeout - } - if files: - request.update({'files': files}) - request['headers'] = { - 'Authorization': self.headers['Authorization'] - } - - response = requests.post(**request) - logger.debug("Response: %s", response.text) - except requests.exceptions.Timeout as e: - raise labelbox.exceptions.TimeoutError(str(e)) - except requests.exceptions.RequestException as e: - logger.error("Unknown error: %s", str(e)) - raise labelbox.exceptions.NetworkError(e) - except Exception as e: - raise labelbox.exceptions.LabelboxError( - "Unknown error during Client.query(): " + str(e), e) - try: - r_json = response.json() - except: - if "upstream connect error or disconnect/reset before headers" \ - in response.text: - raise labelbox.exceptions.InternalServerError( - "Connection reset") - elif response.status_code == 502: - error_502 = '502 Bad Gateway' - raise labelbox.exceptions.InternalServerError(error_502) - - raise labelbox.exceptions.LabelboxError( - "Failed to parse response as JSON: %s" % response.text) - - errors = r_json.get("errors", []) - - def check_errors(keywords, *path): - """ Helper that looks for any of the given `keywords` in any of - current errors on paths (like error[path][component][to][keyword]). 
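            For example, check_errors(["AUTHENTICATION_ERROR"], "extensions",
            "code") returns the first error whose extensions.code field is
            AUTHENTICATION_ERROR, or None if no error matches.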
- """ - for error in errors: - obj = error - for path_elem in path: - obj = obj.get(path_elem, {}) - if obj in keywords: - return error - return None - - def get_error_status_code(error: dict) -> int: - try: - return int(error["extensions"].get("exception").get("status")) - except: - return 500 - - if check_errors(["AUTHENTICATION_ERROR"], "extensions", - "code") is not None: - raise labelbox.exceptions.AuthenticationError("Invalid API key") - - authorization_error = check_errors(["AUTHORIZATION_ERROR"], - "extensions", "code") - if authorization_error is not None: - raise labelbox.exceptions.AuthorizationError( - authorization_error["message"]) - - validation_error = check_errors(["GRAPHQL_VALIDATION_FAILED"], - "extensions", "code") - - if validation_error is not None: - message = validation_error["message"] - if message == "Query complexity limit exceeded": - raise labelbox.exceptions.ValidationFailedError(message) - else: - raise labelbox.exceptions.InvalidQueryError(message) - - graphql_error = check_errors(["GRAPHQL_PARSE_FAILED"], "extensions", - "code") - if graphql_error is not None: - raise labelbox.exceptions.InvalidQueryError( - graphql_error["message"]) - - # Check if API limit was exceeded - response_msg = r_json.get("message", "") - - if response_msg.startswith("You have exceeded"): - raise labelbox.exceptions.ApiLimitError(response_msg) - - resource_not_found_error = check_errors(["RESOURCE_NOT_FOUND"], - "extensions", "code") - if resource_not_found_error is not None: - # Return None and let the caller methods raise an exception - # as they already know which resource type and ID was requested - return None - - resource_conflict_error = check_errors(["RESOURCE_CONFLICT"], - "extensions", "code") - if resource_conflict_error is not None: - raise labelbox.exceptions.ResourceConflict( - resource_conflict_error["message"]) - - malformed_request_error = check_errors(["MALFORMED_REQUEST"], - "extensions", "code") - if malformed_request_error is not None: - raise labelbox.exceptions.MalformedQueryException( - malformed_request_error[error_log_key]) - - # A lot of different error situations are now labeled serverside - # as INTERNAL_SERVER_ERROR, when they are actually client errors. 
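        # The status code embedded in the error payload is read below so that
        # 400 and 426 responses can be re-raised as the more specific
        # InvalidQueryError and OperationNotAllowedException respectively.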
- # TODO: fix this in the server API - internal_server_error = check_errors(["INTERNAL_SERVER_ERROR"], - "extensions", "code") - if internal_server_error is not None: - message = internal_server_error.get("message") - error_status_code = get_error_status_code(internal_server_error) - - if error_status_code == 400: - raise labelbox.exceptions.InvalidQueryError(message) - elif error_status_code == 426: - raise labelbox.exceptions.OperationNotAllowedException(message) - elif error_status_code == 500: - raise labelbox.exceptions.LabelboxError(message) - else: - raise labelbox.exceptions.InternalServerError(message) - - not_allowed_error = check_errors(["OPERATION_NOT_ALLOWED"], - "extensions", "code") - if not_allowed_error is not None: - message = not_allowed_error.get("message") - raise labelbox.exceptions.OperationNotAllowedException(message) - - if len(errors) > 0: - logger.warning("Unparsed errors on query execution: %r", errors) - messages = list( - map( - lambda x: { - "message": x["message"], - "code": x["extensions"]["code"] - }, errors)) - raise labelbox.exceptions.LabelboxError("Unknown error: %s" % - str(messages)) - - # if we do return a proper error code, and didn't catch this above - # reraise - # this mainly catches a 401 for API access disabled for free tier - # TODO: need to unify API errors to handle things more uniformly - # in the SDK - if response.status_code != requests.codes.ok: - message = f"{response.status_code} {response.reason}" - cause = r_json.get('message') - raise labelbox.exceptions.LabelboxError(message, cause) - - return r_json["data"] - - def upload_file(self, path: str) -> str: - """Uploads given path to local file. - - Also includes best guess at the content type of the file. - - Args: - path (str): path to local file to be uploaded. - Returns: - str, the URL of uploaded data. - Raises: - labelbox.exceptions.LabelboxError: If upload failed. - """ - content_type, _ = mimetypes.guess_type(path) - filename = os.path.basename(path) - with open(path, "rb") as f: - return self.upload_data(content=f.read(), - filename=filename, - content_type=content_type) - - @retry.Retry(predicate=retry.if_exception_type( - labelbox.exceptions.InternalServerError)) - def upload_data(self, - content: bytes, - filename: str = None, - content_type: str = None, - sign: bool = False) -> str: - """ Uploads the given data (bytes) to Labelbox. - - Args: - content: bytestring to upload - filename: name of the upload - content_type: content type of data uploaded - sign: whether or not to sign the url - - Returns: - str, the URL of uploaded data. - - Raises: - labelbox.exceptions.LabelboxError: If upload failed. 
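        Example (illustrative values):
            >>> url = client.upload_data(b"sample bytes",
            ...                          filename="sample.txt",
            ...                          content_type="text/plain")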
- """ - - request_data = { - "operations": - json.dumps({ - "variables": { - "file": None, - "contentLength": len(content), - "sign": sign - }, - "query": - """mutation UploadFile($file: Upload!, $contentLength: Int!, - $sign: Boolean) { - uploadFile(file: $file, contentLength: $contentLength, - sign: $sign) {url filename} } """, - }), - "map": (None, json.dumps({"1": ["variables.file"]})), - } - response = requests.post( - self.endpoint, - headers={"authorization": "Bearer %s" % self.api_key}, - data=request_data, - files={ - "1": (filename, content, content_type) if - (filename and content_type) else content - }) - - if response.status_code == 502: - error_502 = '502 Bad Gateway' - raise labelbox.exceptions.InternalServerError(error_502) - elif response.status_code == 503: - raise labelbox.exceptions.InternalServerError(response.text) - elif response.status_code == 520: - raise labelbox.exceptions.InternalServerError(response.text) - - try: - file_data = response.json().get("data", None) - except ValueError as e: # response is not valid JSON - raise labelbox.exceptions.LabelboxError( - "Failed to upload, unknown cause", e) - - if not file_data or not file_data.get("uploadFile", None): - try: - errors = response.json().get("errors", []) - error_msg = next(iter(errors), {}).get("message", - "Unknown error") - except Exception as e: - error_msg = "Unknown error" - raise labelbox.exceptions.LabelboxError( - "Failed to upload, message: %s" % error_msg) - - return file_data["uploadFile"]["url"] - - def _get_single(self, db_object_type, uid): - """ Fetches a single object of the given type, for the given ID. - - Args: - db_object_type (type): DbObject subclass. - uid (str): Unique ID of the row. - Returns: - Object of `db_object_type`. - Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no object - of the given type for the given ID. - """ - query_str, params = query.get_single(db_object_type, uid) - - res = self.execute(query_str, params) - res = res and res.get(utils.camel_case(db_object_type.type_name())) - if res is None: - raise labelbox.exceptions.ResourceNotFoundError( - db_object_type, params) - else: - return db_object_type(self, res) - - def get_project(self, project_id) -> Project: - """ Gets a single Project with the given ID. - - >>> project = client.get_project("") - - Args: - project_id (str): Unique ID of the Project. - Returns: - The sought Project. - Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no - Project with the given ID. - """ - return self._get_single(Entity.Project, project_id) - - def get_dataset(self, dataset_id) -> Dataset: - """ Gets a single Dataset with the given ID. - - >>> dataset = client.get_dataset("") - - Args: - dataset_id (str): Unique ID of the Dataset. - Returns: - The sought Dataset. - Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no - Dataset with the given ID. - """ - return self._get_single(Entity.Dataset, dataset_id) - - def get_user(self) -> User: - """ Gets the current User database object. - - >>> user = client.get_user() - """ - return self._get_single(Entity.User, None) - - def get_organization(self) -> Organization: - """ Gets the Organization DB object of the current user. - - >>> organization = client.get_organization() - """ - return self._get_single(Entity.Organization, None) - - def _get_all(self, db_object_type, where, filter_deleted=True): - """ Fetches all the objects of the given type the user has access to. - - Args: - db_object_type (type): DbObject subclass. 
- where (Comparison, LogicalOperation or None): The `where` clause - for filtering. - Returns: - An iterable of `db_object_type` instances. - """ - if filter_deleted: - not_deleted = db_object_type.deleted == False - where = not_deleted if where is None else where & not_deleted - query_str, params = query.get_all(db_object_type, where) - - return PaginatedCollection( - self, query_str, params, - [utils.camel_case(db_object_type.type_name()) + "s"], - db_object_type) - - def get_projects(self, where=None) -> PaginatedCollection: - """ Fetches all the projects the user has access to. - - >>> projects = client.get_projects(where=(Project.name == "") & (Project.description == "")) - - Args: - where (Comparison, LogicalOperation or None): The `where` clause - for filtering. - Returns: - PaginatedCollection of all projects the user has access to or projects matching the criteria specified. - """ - return self._get_all(Entity.Project, where) - - def get_datasets(self, where=None) -> PaginatedCollection: - """ Fetches one or more datasets. - - >>> datasets = client.get_datasets(where=(Dataset.name == "") & (Dataset.description == "")) - - Args: - where (Comparison, LogicalOperation or None): The `where` clause - for filtering. - Returns: - PaginatedCollection of all datasets the user has access to or datasets matching the criteria specified. - """ - return self._get_all(Entity.Dataset, where) - - def get_labeling_frontends(self, where=None) -> List[LabelingFrontend]: - """ Fetches all the labeling frontends. - - >>> frontend = client.get_labeling_frontends(where=LabelingFrontend.name == "Editor") - - Args: - where (Comparison, LogicalOperation or None): The `where` clause - for filtering. - Returns: - An iterable of LabelingFrontends (typically a PaginatedCollection). - """ - return self._get_all(Entity.LabelingFrontend, where) - - def _create(self, db_object_type, data): - """ Creates an object on the server. Attribute values are - passed as keyword arguments: - - Args: - db_object_type (type): A DbObjectType subtype. - data (dict): Keys are attributes or their names (in Python, - snake-case convention) and values are desired attribute values. - Returns: - A new object of the given DB object type. - Raises: - InvalidAttributeError: If the DB object type does not contain - any of the attribute names given in `data`. - """ - # Convert string attribute names to Field or Relationship objects. - # Also convert Labelbox object values to their UIDs. - data = { - db_object_type.attribute(attr) if isinstance(attr, str) else attr: - value.uid if isinstance(value, DbObject) else value - for attr, value in data.items() - } - - query_string, params = query.create(db_object_type, data) - res = self.execute(query_string, params) - res = res["create%s" % db_object_type.type_name()] - return db_object_type(self, res) - - def create_dataset(self, - iam_integration=IAMIntegration._DEFAULT, - **kwargs) -> Dataset: - """ Creates a Dataset object on the server. - - Attribute values are passed as keyword arguments. - - Args: - iam_integration (IAMIntegration) : Uses the default integration. - Optionally specify another integration or set as None to not use delegated access - **kwargs: Keyword arguments with Dataset attribute values. - Returns: - A new Dataset object. - Raises: - InvalidAttributeError: If the Dataset type does not contain - any of the attribute names given in kwargs. 
- Examples: - Create a dataset - >>> dataset = client.create_dataset(name="") - Create a dataset with description - >>> dataset = client.create_dataset(name="", description="") - """ - dataset = self._create(Entity.Dataset, kwargs) - - if iam_integration == IAMIntegration._DEFAULT: - iam_integration = self.get_organization( - ).get_default_iam_integration() - - if iam_integration is None: - return dataset - - try: - if not isinstance(iam_integration, IAMIntegration): - raise TypeError( - f"iam integration must be a reference an `IAMIntegration` object. Found {type(iam_integration)}" - ) - - if not iam_integration.valid: - raise ValueError( - "Integration is not valid. Please select another.") - - self.execute( - """mutation setSignerForDatasetPyApi($signerId: ID!, $datasetId: ID!) { - setSignerForDataset(data: { signerId: $signerId}, where: {id: $datasetId}){id}} - """, { - 'signerId': iam_integration.uid, - 'datasetId': dataset.uid - }) - validation_result = self.execute( - """mutation validateDatasetPyApi($id: ID!){validateDataset(where: {id : $id}){ - valid checks{name, success}}} - """, {'id': dataset.uid}) - - if not validation_result['validateDataset']['valid']: - raise labelbox.exceptions.LabelboxError( - f"IAMIntegration was not successfully added to the dataset." - ) - except Exception as e: - dataset.delete() - raise e - return dataset - - def create_project(self, **kwargs) -> Project: - """ Creates a Project object on the server. - - Attribute values are passed as keyword arguments. - - >>> project = client.create_project( - name="", - description="", - media_type=MediaType.Image, - queue_mode=QueueMode.Batch - ) - - Args: - name (str): A name for the project - description (str): A short summary for the project - media_type (MediaType): The type of assets that this project will accept - queue_mode (Optional[QueueMode]): The queue mode to use - quality_mode (Optional[QualityMode]): The quality mode to use (e.g. Benchmark, Consensus). Defaults to - Benchmark - Returns: - A new Project object. - Raises: - InvalidAttributeError: If the Project type does not contain - any of the attribute names given in kwargs. - """ - - auto_audit_percentage = kwargs.get("auto_audit_percentage") - auto_audit_number_of_labels = kwargs.get("auto_audit_number_of_labels") - if auto_audit_percentage is not None or auto_audit_number_of_labels is not None: - raise ValueError( - "quality_mode must be set instead of auto_audit_percentage or auto_audit_number_of_labels." - ) - - name = kwargs.get("name") - if name is None or not name.strip(): - raise ValueError("project name must be a valid string.") - - queue_mode = kwargs.get("queue_mode") - if queue_mode is QueueMode.Dataset: - raise ValueError( - "Dataset queue mode is deprecated. Please prefer Batch queue mode." - ) - elif queue_mode is QueueMode.Batch: - logger.warning( - "Passing a queue mode of batch is redundant and will soon no longer be supported." - ) - - media_type = kwargs.get("media_type") - if media_type: - if MediaType.is_supported(media_type): - media_type = media_type.value - else: - raise TypeError(f"{media_type} is not a valid media type. Use" - f" any of {MediaType.get_supported_members()}" - " from MediaType. 
Example: MediaType.Image.") - else: - logger.warning( - "Creating a project without specifying media_type" - " through this method will soon no longer be supported.") - - quality_mode = kwargs.get("quality_mode") - if not quality_mode: - logger.info("Defaulting quality mode to Benchmark.") - - data = kwargs - data.pop("quality_mode", None) - if quality_mode is None or quality_mode is QualityMode.Benchmark: - data[ - "auto_audit_number_of_labels"] = BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS - data["auto_audit_percentage"] = BENCHMARK_AUTO_AUDIT_PERCENTAGE - elif quality_mode is QualityMode.Consensus: - data[ - "auto_audit_number_of_labels"] = CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS - data["auto_audit_percentage"] = CONSENSUS_AUTO_AUDIT_PERCENTAGE - else: - raise ValueError(f"{quality_mode} is not a valid quality mode.") - - return self._create(Entity.Project, { - **data, - **({ - "media_type": media_type - } if media_type else {}) - }) - - def get_roles(self) -> List[Role]: - """ - Returns: - Roles: Provides information on available roles within an organization. - Roles are used for user management. - """ - return role.get_roles(self) - - def get_data_row(self, data_row_id): - """ - - Returns: - DataRow: returns a single data row given the data row id - """ - - return self._get_single(Entity.DataRow, data_row_id) - - def get_data_row_by_global_key(self, global_key: str) -> DataRow: - """ - Returns: DataRow: returns a single data row given the global key - """ - - res = self.get_data_row_ids_for_global_keys([global_key]) - if res['status'] != "SUCCESS": - raise labelbox.exceptions.MalformedQueryException(res['errors'][0]) - if len(res['results']) == 0: - raise labelbox.exceptions.ResourceNotFoundError( - Entity.DataRow, {global_key: global_key}) - data_row_id = res['results'][0] - - return self.get_data_row(data_row_id) - - def get_data_row_metadata_ontology(self) -> DataRowMetadataOntology: - """ - - Returns: - DataRowMetadataOntology: The ontology for Data Row Metadata for an organization - - """ - if self._data_row_metadata_ontology is None: - self._data_row_metadata_ontology = DataRowMetadataOntology(self) - return self._data_row_metadata_ontology - - def get_model(self, model_id) -> Model: - """ Gets a single Model with the given ID. - - >>> model = client.get_model("") - - Args: - model_id (str): Unique ID of the Model. - Returns: - The sought Model. - Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no - Model with the given ID. - """ - return self._get_single(Entity.Model, model_id) - - def get_models(self, where=None) -> List[Model]: - """ Fetches all the models the user has access to. - - >>> models = client.get_models(where=(Model.name == "")) - - Args: - where (Comparison, LogicalOperation or None): The `where` clause - for filtering. - Returns: - An iterable of Models (typically a PaginatedCollection). - """ - return self._get_all(Entity.Model, where, filter_deleted=False) - - def create_model(self, name, ontology_id) -> Model: - """ Creates a Model object on the server. - - >>> model = client.create_model(, ) - - Args: - name (string): Name of the model - ontology_id (string): ID of the related ontology - Returns: - A new Model object. - Raises: - InvalidAttributeError: If the Model type does not contain - any of the attribute names given in kwargs. 
- """ - query_str = """mutation createModelPyApi($name: String!, $ontologyId: ID!){ - createModel(data: {name : $name, ontologyId : $ontologyId}){ - %s - } - }""" % query.results_query_part(Entity.Model) - - result = self.execute(query_str, { - "name": name, - "ontologyId": ontology_id - }) - return Entity.Model(self, result['createModel']) - - def get_data_row_ids_for_external_ids( - self, external_ids: List[str]) -> Dict[str, List[str]]: - """ - Returns a list of data row ids for a list of external ids. - There is a max of 1500 items returned at a time. - - Args: - external_ids: List of external ids to fetch data row ids for - Returns: - A dict of external ids as keys and values as a list of data row ids that correspond to that external id. - """ - query_str = """query externalIdsToDataRowIdsPyApi($externalId_in: [String!]!){ - externalIdsToDataRowIds(externalId_in: $externalId_in) { dataRowId externalId } - } - """ - max_ids_per_request = 100 - result = defaultdict(list) - for i in range(0, len(external_ids), max_ids_per_request): - for row in self.execute( - query_str, - {'externalId_in': external_ids[i:i + max_ids_per_request] - })['externalIdsToDataRowIds']: - result[row['externalId']].append(row['dataRowId']) - return result - - def get_ontology(self, ontology_id) -> Ontology: - """ - Fetches an Ontology by id. - - Args: - ontology_id (str): The id of the ontology to query for - Returns: - Ontology - """ - return self._get_single(Entity.Ontology, ontology_id) - - def get_ontologies(self, name_contains) -> PaginatedCollection: - """ - Fetches all ontologies with names that match the name_contains string. - - Args: - name_contains (str): the string to search ontology names by - Returns: - PaginatedCollection of Ontologies with names that match `name_contains` - """ - query_str = """query getOntologiesPyApi($search: String, $filter: OntologyFilter, $from : String, $first: PageSize){ - ontologies(where: {filter: $filter, search: $search}, after: $from, first: $first){ - nodes {%s} - nextCursor - } - } - """ % query.results_query_part(Entity.Ontology) - params = {'search': name_contains, 'filter': {'status': 'ALL'}} - return PaginatedCollection(self, query_str, params, - ['ontologies', 'nodes'], Entity.Ontology, - ['ontologies', 'nextCursor']) - - def get_feature_schema(self, feature_schema_id): - """ - Fetches a feature schema. Only supports top level feature schemas. 
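
        >>> feature_schema = client.get_feature_schema("")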
- - Args: - feature_schema_id (str): The id of the feature schema to query for - Returns: - FeatureSchema - """ - - query_str = """query rootSchemaNodePyApi($rootSchemaNodeWhere: RootSchemaNodeWhere!){ - rootSchemaNode(where: $rootSchemaNodeWhere){%s} - }""" % query.results_query_part(Entity.FeatureSchema) - res = self.execute( - query_str, - {'rootSchemaNodeWhere': { - 'featureSchemaId': feature_schema_id - }})['rootSchemaNode'] - res['id'] = res['normalized']['featureSchemaId'] - return Entity.FeatureSchema(self, res) - - def get_feature_schemas(self, name_contains) -> PaginatedCollection: - """ - Fetches top level feature schemas with names that match the `name_contains` string - - Args: - name_contains (str): search filter for a name of a root feature schema - If present, results in a case insensitive 'like' search for feature schemas - If None, returns all top level feature schemas - Returns: - PaginatedCollection of FeatureSchemas with names that match `name_contains` - """ - query_str = """query rootSchemaNodesPyApi($search: String, $filter: RootSchemaNodeFilter, $from : String, $first: PageSize){ - rootSchemaNodes(where: {filter: $filter, search: $search}, after: $from, first: $first){ - nodes {%s} - nextCursor - } - } - """ % query.results_query_part(Entity.FeatureSchema) - params = {'search': name_contains, 'filter': {'status': 'ALL'}} - - def rootSchemaPayloadToFeatureSchema(client, payload): - # Technically we are querying for a Schema Node. - # But the features are the same so we just grab the feature schema id - payload['id'] = payload['normalized']['featureSchemaId'] - return Entity.FeatureSchema(client, payload) - - return PaginatedCollection(self, query_str, params, - ['rootSchemaNodes', 'nodes'], - rootSchemaPayloadToFeatureSchema, - ['rootSchemaNodes', 'nextCursor']) - - def create_ontology_from_feature_schemas(self, - name, - feature_schema_ids, - media_type=None) -> Ontology: - """ - Creates an ontology from a list of feature schema ids - - Args: - name (str): Name of the ontology - feature_schema_ids (List[str]): List of feature schema ids corresponding to - top level tools and classifications to include in the ontology - media_type (MediaType or None): Media type of a new ontology - Returns: - The created Ontology - """ - tools, classifications = [], [] - for feature_schema_id in feature_schema_ids: - feature_schema = self.get_feature_schema(feature_schema_id) - tool = ['tool'] - if 'tool' in feature_schema.normalized: - tool = feature_schema.normalized['tool'] - try: - Tool.Type(tool) - tools.append(feature_schema.normalized) - except ValueError: - raise ValueError( - f"Tool `{tool}` not in list of supported tools.") - elif 'type' in feature_schema.normalized: - classification = feature_schema.normalized['type'] - try: - Classification.Type(classification) - classifications.append(feature_schema.normalized) - except ValueError: - raise ValueError( - f"Classification `{classification}` not in list of supported classifications." 
- ) - else: - raise ValueError( - "Neither `tool` or `classification` found in the normalized feature schema" - ) - normalized = {'tools': tools, 'classifications': classifications} - return self.create_ontology(name, normalized, media_type) - - def delete_unused_feature_schema(self, feature_schema_id: str) -> None: - """ - Deletes a feature schema if it is not used by any ontologies or annotations - Args: - feature_schema_id (str): The id of the feature schema to delete - Example: - >>> client.delete_unused_feature_schema("cleabc1my012ioqvu5anyaabc") - """ - - endpoint = self.rest_endpoint + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) - response = requests.delete( - endpoint, - headers=self.headers, - ) - - if response.status_code != requests.codes.no_content: - raise labelbox.exceptions.LabelboxError( - "Failed to delete the feature schema, message: " + - str(response.json()['message'])) - - def delete_unused_ontology(self, ontology_id: str) -> None: - """ - Deletes an ontology if it is not used by any annotations - Args: - ontology_id (str): The id of the ontology to delete - Example: - >>> client.delete_unused_ontology("cleabc1my012ioqvu5anyaabc") - """ - endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote( - ontology_id) - response = requests.delete( - endpoint, - headers=self.headers, - ) - - if response.status_code != requests.codes.no_content: - raise labelbox.exceptions.LabelboxError( - "Failed to delete the ontology, message: " + - str(response.json()['message'])) - - def update_feature_schema_title(self, feature_schema_id: str, - title: str) -> FeatureSchema: - """ - Updates a title of a feature schema - Args: - feature_schema_id (str): The id of the feature schema to update - title (str): The new title of the feature schema - Returns: - The updated feature schema - Example: - >>> client.update_feature_schema_title("cleabc1my012ioqvu5anyaabc", "New Title") - """ - - endpoint = self.rest_endpoint + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) + '/definition' - response = requests.patch( - endpoint, - headers=self.headers, - json={"title": title}, - ) - - if response.status_code == requests.codes.ok: - return self.get_feature_schema(feature_schema_id) - else: - raise labelbox.exceptions.LabelboxError( - "Failed to update the feature schema, message: " + - str(response.json()['message'])) - - def upsert_feature_schema(self, feature_schema: Dict) -> FeatureSchema: - """ - Upserts a feature schema - Args: - feature_schema: Dict representing the feature schema to upsert - Returns: - The upserted feature schema - Example: - Insert a new feature schema - >>> tool = Tool(name="tool", tool=Tool.Type.BOUNDING_BOX, color="#FF0000") - >>> client.upsert_feature_schema(tool.asdict()) - Update an existing feature schema - >>> tool = Tool(feature_schema_id="cleabc1my012ioqvu5anyaabc", name="tool", tool=Tool.Type.BOUNDING_BOX, color="#FF0000") - >>> client.upsert_feature_schema(tool.asdict()) - """ - - feature_schema_id = feature_schema.get( - "featureSchemaId") or "new_feature_schema_id" - endpoint = self.rest_endpoint + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) - response = requests.put( - endpoint, - headers=self.headers, - json={"normalized": json.dumps(feature_schema)}, - ) - - if response.status_code == requests.codes.ok: - return self.get_feature_schema(response.json()['schemaId']) - else: - raise labelbox.exceptions.LabelboxError( - "Failed to upsert the feature schema, message: " + - 
str(response.json()['message'])) - - def insert_feature_schema_into_ontology(self, feature_schema_id: str, - ontology_id: str, - position: int) -> None: - """ - Inserts a feature schema into an ontology. If the feature schema is already in the ontology, - it will be moved to the new position. - Args: - feature_schema_id (str): The feature schema id to upsert - ontology_id (str): The id of the ontology to insert the feature schema into - position (int): The position number of the feature schema in the ontology - Example: - >>> client.insert_feature_schema_into_ontology("cleabc1my012ioqvu5anyaabc", "clefdvwl7abcgefgu3lyvcde", 2) - """ - - endpoint = self.rest_endpoint + '/ontologies/' + urllib.parse.quote( - ontology_id) + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) - response = requests.post( - endpoint, - headers=self.headers, - json={"position": position}, - ) - if response.status_code != requests.codes.created: - raise labelbox.exceptions.LabelboxError( - "Failed to insert the feature schema into the ontology, message: " - + str(response.json()['message'])) - - def get_unused_ontologies(self, after: str = None) -> List[str]: - """ - Returns a list of unused ontology ids - Args: - after (str): The cursor to use for pagination - Returns: - A list of unused ontology ids - Example: - To get the first page of unused ontology ids (100 at a time) - >>> client.get_unused_ontologies() - To get the next page of unused ontology ids - >>> client.get_unused_ontologies("cleabc1my012ioqvu5anyaabc") - """ - - endpoint = self.rest_endpoint + "/ontologies/unused" - response = requests.get( - endpoint, - headers=self.headers, - json={"after": after}, - ) - - if response.status_code == requests.codes.ok: - return response.json() - else: - raise labelbox.exceptions.LabelboxError( - "Failed to get unused ontologies, message: " + - str(response.json()['message'])) - - def get_unused_feature_schemas(self, after: str = None) -> List[str]: - """ - Returns a list of unused feature schema ids - Args: - after (str): The cursor to use for pagination - Returns: - A list of unused feature schema ids - Example: - To get the first page of unused feature schema ids (100 at a time) - >>> client.get_unused_feature_schemas() - To get the next page of unused feature schema ids - >>> client.get_unused_feature_schemas("cleabc1my012ioqvu5anyaabc") - """ - - endpoint = self.rest_endpoint + "/feature-schemas/unused" - response = requests.get( - endpoint, - headers=self.headers, - json={"after": after}, - ) - - if response.status_code == requests.codes.ok: - return response.json() - else: - raise labelbox.exceptions.LabelboxError( - "Failed to get unused feature schemas, message: " + - str(response.json()['message'])) - - def create_ontology(self, name, normalized, media_type=None) -> Ontology: - """ - Creates an ontology from normalized data - >>> normalized = {"tools" : [{'tool': 'polygon', 'name': 'cat', 'color': 'black'}], "classifications" : []} - >>> ontology = client.create_ontology("ontology-name", normalized) - - Or use the ontology builder. 
It is especially useful for complex ontologies - >>> normalized = OntologyBuilder(tools=[Tool(tool=Tool.Type.BBOX, name="cat", color = 'black')]).asdict() - >>> ontology = client.create_ontology("ontology-name", normalized) - - To reuse existing feature schemas, use `create_ontology_from_feature_schemas()` - More details can be found here: - https://github.com/Labelbox/labelbox-python/blob/develop/examples/basics/ontologies.ipynb - - Args: - name (str): Name of the ontology - normalized (dict): A normalized ontology payload. See above for details. - media_type (MediaType or None): Media type of a new ontology - Returns: - The created Ontology - """ - - if media_type: - if MediaType.is_supported(media_type): - media_type = media_type.value - else: - raise get_media_type_validation_error(media_type) - - query_str = """mutation upsertRootSchemaNodePyApi($data: UpsertOntologyInput!){ - upsertOntology(data: $data){ %s } - } """ % query.results_query_part(Entity.Ontology) - params = { - 'data': { - 'name': name, - 'normalized': json.dumps(normalized), - 'mediaType': media_type - } - } - res = self.execute(query_str, params) - return Entity.Ontology(self, res['upsertOntology']) - - def create_feature_schema(self, normalized): - """ - Creates a feature schema from normalized data. - >>> normalized = {'tool': 'polygon', 'name': 'cat', 'color': 'black'} - >>> feature_schema = client.create_feature_schema(normalized) - - Or use the Tool or Classification objects. It is especially useful for complex tools. - >>> normalized = Tool(tool=Tool.Type.BBOX, name="cat", color = 'black').asdict() - >>> feature_schema = client.create_feature_schema(normalized) - - Subclasses are also supported - >>> normalized = Tool( - tool=Tool.Type.SEGMENTATION, - name="cat", - classifications=[ - Classification( - class_type=Classification.Type.TEXT, - name="name" - ) - ] - ) - >>> feature_schema = client.create_feature_schema(normalized) - - More details can be found here: - https://github.com/Labelbox/labelbox-python/blob/develop/examples/basics/ontologies.ipynb - - Args: - normalized (dict): A normalized tool or classification payload. See above for details - Returns: - The created FeatureSchema. - """ - query_str = """mutation upsertRootSchemaNodePyApi($data: UpsertRootSchemaNodeInput!){ - upsertRootSchemaNode(data: $data){ %s } - } """ % query.results_query_part(Entity.FeatureSchema) - normalized = {k: v for k, v in normalized.items() if v} - params = {'data': {'normalized': json.dumps(normalized)}} - res = self.execute(query_str, params)['upsertRootSchemaNode'] - # Technically we are querying for a Schema Node. - # But the features are the same so we just grab the feature schema id - res['id'] = res['normalized']['featureSchemaId'] - return Entity.FeatureSchema(self, res) - - def get_model_run(self, model_run_id: str) -> ModelRun: - """ Gets a single ModelRun with the given ID. - - >>> model_run = client.get_model_run("") - - Args: - model_run_id (str): Unique ID of the ModelRun. - Returns: - A ModelRun object. - """ - return self._get_single(Entity.ModelRun, model_run_id) - - def assign_global_keys_to_data_rows( - self, - global_key_to_data_row_inputs: List[Dict[str, str]], - timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]: - """ - Assigns global keys to data rows. - - Args: - A list of dicts containing data_row_id and global_key. - Returns: - Dictionary containing 'status', 'results' and 'errors'. - - 'Status' contains the outcome of this job. It can be one of - 'Success', 'Partial Success', or 'Failure'. 
- - 'Results' contains the successful global_key assignments, including - global_keys that have been sanitized to Labelbox standards. - - 'Errors' contains global_key assignments that failed, along with - the reasons for failure. - Examples: - >>> global_key_data_row_inputs = [ - {"data_row_id": "cl7asgri20yvo075b4vtfedjb", "global_key": "key1"}, - {"data_row_id": "cl7asgri10yvg075b4pz176ht", "global_key": "key2"}, - ] - >>> job_result = client.assign_global_keys_to_data_rows(global_key_data_row_inputs) - >>> print(job_result['status']) - Partial Success - >>> print(job_result['results']) - [{'data_row_id': 'cl7tv9wry00hlka6gai588ozv', 'global_key': 'gk', 'sanitized': False}] - >>> print(job_result['errors']) - [{'data_row_id': 'cl7tpjzw30031ka6g4evqdfoy', 'global_key': 'gk"', 'error': 'Invalid global key'}] - """ - - def _format_successful_rows(rows: Dict[str, str], - sanitized: bool) -> List[Dict[str, str]]: - return [{ - 'data_row_id': r['dataRowId'], - 'global_key': r['globalKey'], - 'sanitized': sanitized - } for r in rows] - - def _format_failed_rows(rows: Dict[str, str], - error_msg: str) -> List[Dict[str, str]]: - return [{ - 'data_row_id': r['dataRowId'], - 'global_key': r['globalKey'], - 'error': error_msg - } for r in rows] - - # Validate input dict - validation_errors = [] - for input in global_key_to_data_row_inputs: - if "data_row_id" not in input or "global_key" not in input: - validation_errors.append(input) - if len(validation_errors) > 0: - raise ValueError( - f"Must provide a list of dicts containing both `data_row_id` and `global_key`. The following dict(s) are invalid: {validation_errors}." - ) - - # Start assign global keys to data rows job - query_str = """mutation assignGlobalKeysToDataRowsPyApi($globalKeyDataRowLinks: [AssignGlobalKeyToDataRowInput!]!) { - assignGlobalKeysToDataRows(data: {assignInputs: $globalKeyDataRowLinks}) { - jobId - } - } - """ - params = { - 'globalKeyDataRowLinks': [{ - utils.camel_case(key): value for key, value in input.items() - } for input in global_key_to_data_row_inputs] - } - assign_global_keys_to_data_rows_job = self.execute(query_str, params) - - # Query string for retrieving job status and result, if job is done - result_query_str = """query assignGlobalKeysToDataRowsResultPyApi($jobId: ID!) { - assignGlobalKeysToDataRowsResult(jobId: {id: $jobId}) { - jobStatus - data { - sanitizedAssignments { - dataRowId - globalKey - } - invalidGlobalKeyAssignments { - dataRowId - globalKey - } - unmodifiedAssignments { - dataRowId - globalKey - } - accessDeniedAssignments { - dataRowId - globalKey - } - }}} - """ - result_params = { - "jobId": - assign_global_keys_to_data_rows_job["assignGlobalKeysToDataRows" - ]["jobId"] - } - - # Poll job status until finished, then retrieve results - sleep_time = 2 - start_time = time.time() - while True: - res = self.execute(result_query_str, result_params) - if res["assignGlobalKeysToDataRowsResult"][ - "jobStatus"] == "COMPLETE": - results, errors = [], [] - res = res['assignGlobalKeysToDataRowsResult']['data'] - # Successful assignments - results.extend( - _format_successful_rows(rows=res['sanitizedAssignments'], - sanitized=True)) - results.extend( - _format_successful_rows(rows=res['unmodifiedAssignments'], - sanitized=False)) - # Failed assignments - errors.extend( - _format_failed_rows( - rows=res['invalidGlobalKeyAssignments'], - error_msg= - "Invalid assignment. 
Either DataRow does not exist, or globalKey is invalid" - )) - errors.extend( - _format_failed_rows(rows=res['accessDeniedAssignments'], - error_msg="Access denied to Data Row")) - - if not errors: - status = CollectionJobStatus.SUCCESS.value - elif errors and results: - status = CollectionJobStatus.PARTIAL_SUCCESS.value - else: - status = CollectionJobStatus.FAILURE.value - - if errors: - logger.warning( - "There are errors present. Please look at 'errors' in the returned dict for more details" - ) - - return { - "status": status, - "results": results, - "errors": errors, - } - elif res["assignGlobalKeysToDataRowsResult"][ - "jobStatus"] == "FAILED": - raise labelbox.exceptions.LabelboxError( - "Job assign_global_keys_to_data_rows failed.") - current_time = time.time() - if current_time - start_time > timeout_seconds: - raise labelbox.exceptions.TimeoutError( - "Timed out waiting for assign_global_keys_to_data_rows job to complete." - ) - time.sleep(sleep_time) - - def get_data_row_ids_for_global_keys( - self, - global_keys: List[str], - timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]: - """ - Gets data row ids for a list of global keys. - - Deprecation Notice: This function will soon no longer return 'Deleted Data Rows' - as part of the 'results'. Global keys for deleted data rows will soon be placed - under 'Data Row not found' portion. - - Args: - A list of global keys - Returns: - Dictionary containing 'status', 'results' and 'errors'. - - 'Status' contains the outcome of this job. It can be one of - 'Success', 'Partial Success', or 'Failure'. - - 'Results' contains a list of the fetched corresponding data row ids in the input order. - For data rows that cannot be fetched due to an error, or data rows that do not exist, - empty string is returned at the position of the respective global_key. - More error information can be found in the 'Errors' section. - - 'Errors' contains a list of global_keys that could not be fetched, along - with the failure reason - Examples: - >>> job_result = client.get_data_row_ids_for_global_keys(["key1","key2"]) - >>> print(job_result['status']) - Partial Success - >>> print(job_result['results']) - ['cl7tv9wry00hlka6gai588ozv', 'cl7tv9wxg00hpka6gf8sh81bj'] - >>> print(job_result['errors']) - [{'global_key': 'asdf', 'error': 'Data Row not found'}] - """ - - def _format_failed_rows(rows: List[str], - error_msg: str) -> List[Dict[str, str]]: - return [{'global_key': r, 'error': error_msg} for r in rows] - - # Start get data rows for global keys job - query_str = """query getDataRowsForGlobalKeysPyApi($globalKeys: [ID!]!) { - dataRowsForGlobalKeys(where: {ids: $globalKeys}) { jobId}} - """ - params = {"globalKeys": global_keys} - data_rows_for_global_keys_job = self.execute(query_str, params) - - # Query string for retrieving job status and result, if job is done - result_query_str = """query getDataRowsForGlobalKeysResultPyApi($jobId: ID!) 
{ - dataRowsForGlobalKeysResult(jobId: {id: $jobId}) { data { - fetchedDataRows { id } - notFoundGlobalKeys - accessDeniedGlobalKeys - } jobStatus}} - """ - result_params = { - "jobId": - data_rows_for_global_keys_job["dataRowsForGlobalKeys"]["jobId"] - } - - # Poll job status until finished, then retrieve results - sleep_time = 2 - start_time = time.time() - while True: - res = self.execute(result_query_str, result_params) - if res["dataRowsForGlobalKeysResult"]['jobStatus'] == "COMPLETE": - data = res["dataRowsForGlobalKeysResult"]['data'] - results, errors = [], [] - results.extend([row['id'] for row in data['fetchedDataRows']]) - errors.extend( - _format_failed_rows(data['notFoundGlobalKeys'], - "Data Row not found")) - errors.extend( - _format_failed_rows(data['accessDeniedGlobalKeys'], - "Access denied to Data Row")) - - # Invalid results may contain empty string, so we must filter - # them prior to checking for PARTIAL_SUCCESS - filtered_results = list(filter(lambda r: r != '', results)) - if not errors: - status = CollectionJobStatus.SUCCESS.value - elif errors and len(filtered_results) > 0: - status = CollectionJobStatus.PARTIAL_SUCCESS.value - else: - status = CollectionJobStatus.FAILURE.value - - if errors: - logger.warning( - "There are errors present. Please look at 'errors' in the returned dict for more details" - ) - - return {"status": status, "results": results, "errors": errors} - elif res["dataRowsForGlobalKeysResult"]['jobStatus'] == "FAILED": - raise labelbox.exceptions.LabelboxError( - "Job dataRowsForGlobalKeys failed.") - current_time = time.time() - if current_time - start_time > timeout_seconds: - raise labelbox.exceptions.TimeoutError( - "Timed out waiting for get_data_rows_for_global_keys job to complete." - ) - time.sleep(sleep_time) - - def clear_global_keys( - self, - global_keys: List[str], - timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]: - """ - Clears global keys for the data rows tha correspond to the global keys provided. - - Args: - A list of global keys - Returns: - Dictionary containing 'status', 'results' and 'errors'. - - 'Status' contains the outcome of this job. It can be one of - 'Success', 'Partial Success', or 'Failure'. - - 'Results' contains a list global keys that were successfully cleared. - - 'Errors' contains a list of global_keys correspond to the data rows that could not be - modified, accessed by the user, or not found. - Examples: - >>> job_result = client.clear_global_keys(["key1","key2","notfoundkey"]) - >>> print(job_result['status']) - Partial Success - >>> print(job_result['results']) - ['key1', 'key2'] - >>> print(job_result['errors']) - [{'global_key': 'notfoundkey', 'error': 'Failed to find data row matching provided global key'}] - """ - - def _format_failed_rows(rows: List[str], - error_msg: str) -> List[Dict[str, str]]: - return [{'global_key': r, 'error': error_msg} for r in rows] - - # Start get data rows for global keys job - query_str = """mutation clearGlobalKeysPyApi($globalKeys: [ID!]!) { - clearGlobalKeys(where: {ids: $globalKeys}) { jobId}} - """ - params = {"globalKeys": global_keys} - clear_global_keys_job = self.execute(query_str, params) - - # Query string for retrieving job status and result, if job is done - result_query_str = """query clearGlobalKeysResultPyApi($jobId: ID!) 
{ - clearGlobalKeysResult(jobId: {id: $jobId}) { data { - clearedGlobalKeys - failedToClearGlobalKeys - notFoundGlobalKeys - accessDeniedGlobalKeys - } jobStatus}} - """ - result_params = { - "jobId": clear_global_keys_job["clearGlobalKeys"]["jobId"] - } - # Poll job status until finished, then retrieve results - sleep_time = 2 - start_time = time.time() - while True: - res = self.execute(result_query_str, result_params) - if res["clearGlobalKeysResult"]['jobStatus'] == "COMPLETE": - data = res["clearGlobalKeysResult"]['data'] - results, errors = [], [] - results.extend(data['clearedGlobalKeys']) - errors.extend( - _format_failed_rows(data['failedToClearGlobalKeys'], - "Clearing global key failed")) - errors.extend( - _format_failed_rows( - data['notFoundGlobalKeys'], - "Failed to find data row matching provided global key")) - errors.extend( - _format_failed_rows( - data['accessDeniedGlobalKeys'], - "Denied access to modify data row matching provided global key" - )) - - if not errors: - status = CollectionJobStatus.SUCCESS.value - elif errors and len(results) > 0: - status = CollectionJobStatus.PARTIAL_SUCCESS.value - else: - status = CollectionJobStatus.FAILURE.value - - if errors: - logger.warning( - "There are errors present. Please look at 'errors' in the returned dict for more details" - ) - - return {"status": status, "results": results, "errors": errors} - elif res["clearGlobalKeysResult"]['jobStatus'] == "FAILED": - raise labelbox.exceptions.LabelboxError( - "Job clearGlobalKeys failed.") - current_time = time.time() - if current_time - start_time > timeout_seconds: - raise labelbox.exceptions.TimeoutError( - "Timed out waiting for clear_global_keys job to complete.") - time.sleep(sleep_time) - - def get_catalog(self) -> Catalog: - return Catalog(client=self) - - def get_catalog_slice(self, slice_id) -> CatalogSlice: - """ - Fetches a Catalog Slice by ID. - - Args: - slice_id (str): The ID of the Slice - Returns: - CatalogSlice - """ - query_str = """query getSavedQueryPyApi($id: ID!) { - getSavedQuery(id: $id) { - id - name - description - filter - createdAt - updatedAt - } - } - """ - res = self.execute(query_str, {'id': slice_id}) - return Entity.CatalogSlice(self, res['getSavedQuery']) - - def is_feature_schema_archived(self, ontology_id: str, - feature_schema_id: str) -> bool: - """ - Returns true if a feature schema is archived in the specified ontology, returns false otherwise. 
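        The archived flag is read from the ontology's feature schema nodes
        (tools, classifications, and relationships).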
- - Args: - feature_schema_id (str): The ID of the feature schema - ontology_id (str): The ID of the ontology - Returns: - bool - """ - - ontology_endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote( - ontology_id) - response = requests.get( - ontology_endpoint, - headers=self.headers, - ) - - if response.status_code == requests.codes.ok: - feature_schema_nodes = response.json()['featureSchemaNodes'] - tools = feature_schema_nodes['tools'] - classifications = feature_schema_nodes['classifications'] - relationships = feature_schema_nodes['relationships'] - feature_schema_node_list = tools + classifications + relationships - filtered_feature_schema_nodes = [ - feature_schema_node - for feature_schema_node in feature_schema_node_list - if feature_schema_node['featureSchemaId'] == feature_schema_id - ] - if filtered_feature_schema_nodes: - return bool(filtered_feature_schema_nodes[0]['archived']) - else: - raise labelbox.exceptions.LabelboxError( - "The specified feature schema was not in the ontology.") - - elif response.status_code == 404: - raise labelbox.exceptions.ResourceNotFoundError( - Ontology, ontology_id) - else: - raise labelbox.exceptions.LabelboxError( - "Failed to get the feature schema archived status.") - - def get_model_slice(self, slice_id) -> ModelSlice: - """ - Fetches a Model Slice by ID. - - Args: - slice_id (str): The ID of the Slice - Returns: - ModelSlice - """ - query_str = """ - query getSavedQueryPyApi($id: ID!) { - getSavedQuery(id: $id) { - id - name - description - filter - createdAt - updatedAt - } - } - """ - res = self.execute(query_str, {"id": slice_id}) - if res is None or res["getSavedQuery"] is None: - raise labelbox.exceptions.ResourceNotFoundError( - ModelSlice, slice_id) - - return Entity.ModelSlice(self, res["getSavedQuery"]) - - def delete_feature_schema_from_ontology( - self, ontology_id: str, - feature_schema_id: str) -> DeleteFeatureFromOntologyResult: - """ - Deletes or archives a feature schema from an ontology. - If the feature schema is a root level node with associated labels, it will be archived. - If the feature schema is a nested node in the ontology and does not have associated labels, it will be deleted. - If the feature schema is a nested node in the ontology and has associated labels, it will not be deleted. - - Args: - ontology_id (str): The ID of the ontology. - feature_schema_id (str): The ID of the feature schema. - - Returns: - DeleteFeatureFromOntologyResult: The result of the feature schema removal. - - Example: - >>> client.delete_feature_schema_from_ontology(, ) - """ - ontology_endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote( - ontology_id) + "/feature-schemas/" + urllib.parse.quote( - feature_schema_id) - response = requests.delete( - ontology_endpoint, - headers=self.headers, - ) - - if response.status_code == requests.codes.ok: - response_json = response.json() - if response_json['archived'] == True: - logger.info( - 'Feature schema was archived from the ontology because it had associated labels.' 
- )
- elif response_json['deleted'] == True:
- logger.info(
- 'Feature schema was successfully removed from the ontology')
- result = DeleteFeatureFromOntologyResult()
- result.archived = bool(response_json['archived'])
- result.deleted = bool(response_json['deleted'])
- return result
- else:
- raise labelbox.exceptions.LabelboxError(
- "Failed to remove feature schema from ontology, message: " +
- str(response.json()['message']))
- 
- def unarchive_feature_schema_node(self, ontology_id: str,
- root_feature_schema_id: str) -> None:
- """
- Unarchives a feature schema node in an ontology.
- Only root level feature schema nodes can be unarchived.
- Args:
- ontology_id (str): The ID of the ontology
- root_feature_schema_id (str): The ID of the root level feature schema
- Returns:
- None
- """
- ontology_endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote(
- ontology_id) + '/feature-schemas/' + urllib.parse.quote(
- root_feature_schema_id) + '/unarchive'
- response = requests.patch(
- ontology_endpoint,
- headers=self.headers,
- )
- if response.status_code == requests.codes.ok:
- if not bool(response.json()['unarchived']):
- raise labelbox.exceptions.LabelboxError(
- "Failed to unarchive the feature schema.")
- else:
- raise labelbox.exceptions.LabelboxError(
- "Failed to unarchive the feature schema node, message: ",
- response.text)
- 
- def get_batch(self, project_id: str, batch_id: str) -> Entity.Batch:
- # obtain batch entity to return
- get_batch_str = """query %s($projectId: ID!, $batchId: ID!) {
- project(where: {id: $projectId}) {
- batches(where: {id: $batchId}) {
- nodes {
- %s
- }
- }
- }
- }
- """ % ("getProjectBatchPyApi",
- query.results_query_part(Entity.Batch))
- 
- batch = self.execute(
- get_batch_str, {
- "projectId": project_id,
- "batchId": batch_id
- },
- timeout=180.0,
- experimental=True)["project"]["batches"]["nodes"][0]
- 
- return Entity.Batch(self, project_id, batch)
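- 
- # Illustrative sketch (not part of the SDK): fetching a single batch by id.
- # The ids below are placeholders.
- #
- #   batch = client.get_batch("<project_id>", "<batch_id>")
- #   print(batch.name)
- 
- def send_to_annotate_from_catalog(self, destination_project_id: str,
- task_queue_id: Optional[str],
- batch_name: str,
- data_rows: Union[DataRowIds, GlobalKeys],
- params: SendToAnnotateFromCatalogParams):
- """
- Sends data rows from catalog to a specified project for annotation.
- 
- Example usage:
- >>> task = client.send_to_annotate_from_catalog(
- >>> destination_project_id=DESTINATION_PROJECT_ID,
- >>> task_queue_id=TASK_QUEUE_ID,
- >>> batch_name="batch_name",
- >>> data_rows=UniqueIds([DATA_ROW_ID]),
- >>> params={
- >>> "source_project_id":
- >>> SOURCE_PROJECT_ID,
- >>> "override_existing_annotations_rule":
- >>> ConflictResolutionStrategy.OverrideWithAnnotations
- >>> })
- >>> task.wait_till_done()
- 
- Args:
- destination_project_id: The ID of the project to send the data rows to.
- task_queue_id: The ID of the task queue to send the data rows to. If not specified, the data rows will be
- sent to the Done workflow state.
- batch_name: The name of the batch to create. If more than one batch is created, additional batches will be
- named with a monotonically increasing numerical suffix, starting at "_1".
- data_rows: The data rows to send to the project.
- params: Additional parameters to configure the job. See SendToAnnotateFromCatalogParams for more details.
- 
- Returns: The created task for this operation.
- 
- """
- 
- mutation_str = """mutation SendToAnnotateFromCatalogPyApi($input: SendToAnnotateFromCatalogInput!) 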
{ - sendToAnnotateFromCatalog(input: $input) { - taskId - } - } - """ - - destination_task_queue = build_destination_task_queue_input( - task_queue_id) - data_rows_query = self.build_catalog_query(data_rows) - - source_model_run_id = params.get("source_model_run_id", None) - predictions_ontology_mapping = params.get( - "predictions_ontology_mapping", None) - predictions_input = build_predictions_input( - predictions_ontology_mapping, - source_model_run_id) if source_model_run_id else None - - source_project_id = params.get("source_project_id", None) - annotations_ontology_mapping = params.get( - "annotations_ontology_mapping", None) - annotations_input = build_annotations_input( - annotations_ontology_mapping, - source_project_id) if source_project_id else None - - batch_priority = params.get("batch_priority", 5) - exclude_data_rows_in_project = params.get( - "exclude_data_rows_in_project", False) - override_existing_annotations_rule = params.get( - "override_existing_annotations_rule", - ConflictResolutionStrategy.KeepExisting) - - res = self.execute( - mutation_str, { - "input": { - "destinationProjectId": - destination_project_id, - "batchInput": { - "batchName": batch_name, - "batchPriority": batch_priority - }, - "destinationTaskQueue": - destination_task_queue, - "excludeDataRowsInProject": - exclude_data_rows_in_project, - "annotationsInput": - annotations_input, - "predictionsInput": - predictions_input, - "conflictLabelsResolutionStrategy": - override_existing_annotations_rule, - "searchQuery": { - "scope": None, - "query": [data_rows_query] - }, - "ordering": { - "type": "RANDOM", - "random": { - "seed": random.randint(0, 10000) - }, - "sorting": None - }, - "sorting": - None, - "limit": - None - } - })['sendToAnnotateFromCatalog'] - - return Entity.Task.get_task(self, res['taskId']) - - @staticmethod - def build_catalog_query(data_rows: Union[DataRowIds, GlobalKeys]): - """ - Given a list of data rows, builds a query that can be used to fetch the associated data rows from the catalog. - - Args: - data_rows: A list of data rows. Can be either UniqueIds or GlobalKeys. - - Returns: A query that can be used to fetch the associated data rows from the catalog. - - """ - if isinstance(data_rows, DataRowIds): - data_rows_query = { - "type": "data_row_id", - "operator": "is", - "ids": list(data_rows) - } - elif isinstance(data_rows, GlobalKeys): - data_rows_query = { - "type": "global_key", - "operator": "is", - "ids": list(data_rows) - } - else: - raise ValueError( - f"Invalid data_rows type {type(data_rows)}. 
Type of data_rows must be DataRowIds or GlobalKeys"
- )
- return data_rows_query
- 
- def run_foundry_app(self, model_run_name: str, data_rows: Union[DataRowIds,
- GlobalKeys],
- app_id: str) -> Task:
- """
- Run a foundry app
- 
- Args:
- model_run_name (str): Name of a new model run to store app predictions in
- data_rows (DataRowIds or GlobalKeys): Data row identifiers to run predictions on
- app_id (str): Foundry app to run predictions with
- """
- foundry_client = FoundryClient(self)
- return foundry_client.run_app(model_run_name, data_rows, app_id)
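- 
- # Illustrative sketch (not part of the SDK): running a foundry app over a
- # handful of data rows. The ids below are placeholders.
- #
- #   task = client.run_foundry_app(
- #       model_run_name="my model run",
- #       data_rows=GlobalKeys(["<global_key>"]),
- #       app_id="<app_id>")
- #   task.wait_till_done()
-
-----
-labelbox/__init__.py
-name = "labelbox"
-
-__version__ = "3.65.0"
-
-from labelbox.client import Client
-from labelbox.schema.project import Project
-from labelbox.schema.model import Model
-from labelbox.schema.bulk_import_request import BulkImportRequest
-from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport, LabelImport, MEAToMALPredictionImport
-from labelbox.schema.dataset import Dataset
-from labelbox.schema.data_row import DataRow
-from labelbox.schema.catalog import Catalog
-from labelbox.schema.enums import AnnotationImportState
-from labelbox.schema.label import Label
-from labelbox.schema.batch import Batch
-from labelbox.schema.review import Review
-from labelbox.schema.user import User
-from labelbox.schema.organization import Organization
-from labelbox.schema.task import Task
-from labelbox.schema.export_task import StreamType, ExportTask, JsonConverter, JsonConverterOutput, FileConverter, FileConverterOutput
-from labelbox.schema.labeling_frontend import LabelingFrontend, LabelingFrontendOptions
-from labelbox.schema.asset_attachment import AssetAttachment
-from labelbox.schema.webhook import Webhook
-from labelbox.schema.ontology import Ontology, OntologyBuilder, Classification, Option, Tool, FeatureSchema
-from labelbox.schema.role import Role, ProjectRole
-from labelbox.schema.invite import Invite, InviteLimit
-from labelbox.schema.data_row_metadata import DataRowMetadataOntology, DataRowMetadataField, DataRowMetadata, DeleteDataRowMetadata
-from labelbox.schema.model_run import ModelRun, DataSplit
-from labelbox.schema.benchmark import Benchmark
-from labelbox.schema.iam_integration import IAMIntegration
-from labelbox.schema.resource_tag import ResourceTag
-from labelbox.schema.project_resource_tag import ProjectResourceTag
-from labelbox.schema.media_type import MediaType
-from labelbox.schema.slice import Slice, CatalogSlice, ModelSlice
-from labelbox.schema.queue_mode import QueueMode
-from labelbox.schema.task_queue import TaskQueue
-from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds
-from labelbox.schema.identifiable import UniqueId, GlobalKey
-
-----
-labelbox/types.py
-try:
- from labelbox.data.annotation_types import *
-except ImportError:
- raise ImportError(
- "There are missing dependencies for `labelbox.types`, use `pip install labelbox[data] --upgrade` to install missing dependencies." 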
- )
-----
-labelbox/parser.py
-import json
-
-
-class NdjsonDecoder(json.JSONDecoder):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- def decode(self, s: str, *args, **kwargs):
- lines = ','.join(s.splitlines())
- text = f"[{lines}]" # NOTE: this is a hack to make json.loads work for ndjson
- return super().decode(text, *args, **kwargs)
-
-
-def loads(ndjson_string, **kwargs) -> list:
- kwargs.setdefault('cls', NdjsonDecoder)
- return json.loads(ndjson_string, **kwargs)
-
-
-def dumps(obj, **kwargs):
- lines = map(lambda obj: json.dumps(obj, **kwargs), obj)
- return '\n'.join(lines)
-
-
-def dump(obj, io, **kwargs):
- lines = dumps(obj, **kwargs)
- io.write(lines)
-
-
-def reader(io_handle, **kwargs):
- for line in io_handle:
- yield json.loads(line, **kwargs)
-
-----
-labelbox/utils.py
-import datetime
-import re
-
-from dateutil.tz import tzoffset
-from dateutil.parser import isoparse as dateutil_parse
-from dateutil.utils import default_tzinfo
-
-from urllib.parse import urlparse
-from labelbox import pydantic_compat
-
-UPPERCASE_COMPONENTS = ['uri', 'rgb']
-ISO_DATETIME_FORMAT = '%Y-%m-%dT%H:%M:%SZ'
-DFLT_TZ = tzoffset("UTC", 0000)
-
-
-def _convert(s, sep, title):
- components = re.findall(r"[A-Z][a-z0-9]*|[a-z][a-z0-9]*", s)
- components = list(map(str.lower, filter(None, components)))
- for i in range(len(components)):
- if components[i] in UPPERCASE_COMPONENTS:
- components[i] = components[i].upper()
- elif title(i):
- components[i] = components[i][0].upper() + components[i][1:]
- return sep.join(components)
-
-
-def camel_case(s):
- """ Converts a string in [snake|camel|title]case to camelCase. """
- return _convert(s, "", lambda i: i > 0)
-
-
-def title_case(s):
- """ Converts a string in [snake|camel|title]case to TitleCase. """
- return _convert(s, "", lambda i: True)
-
-
-def snake_case(s):
- """ Converts a string in [snake|camel|title]case to snake_case. """
- return _convert(s, "_", lambda i: False)
-
-
-def is_exactly_one_set(x, y):
- return not (bool(x) == bool(y))
-
-
-def is_valid_uri(uri):
- try:
- result = urlparse(uri)
- return all([result.scheme, result.netloc])
- except:
- return False
-
-
-class _CamelCaseMixin(pydantic_compat.BaseModel):
-
- class Config:
- allow_population_by_field_name = True
- alias_generator = camel_case
-
-
-class _NoCoercionMixin:
- """
- When using Unions in type annotations, pydantic will try to coerce the type
- of the object to the type of the first Union member, which results in
- unintended behavior.
-
- This mixin uses a class_name discriminator field to prevent pydantic from
- coercing the type of the object. Add a class_name field to the class you
- want to discriminate and use this mixin class to remove the discriminator
- when serializing the object.
-
- Example:
- class ConversationData(BaseData, _NoCoercionMixin):
- class_name: Literal["ConversationData"] = "ConversationData"
-
- """
-
- def dict(self, *args, **kwargs):
- res = super().dict(*args, **kwargs)
- res.pop('class_name')
- return res
-
-
-def format_iso_datetime(dt: datetime.datetime) -> str:
- """
- Formats a datetime object into the format: 2011-11-04T00:05:23Z
- Note that datetime.isoformat() outputs 2011-11-04T00:05:23+00:00
- """
- return dt.astimezone(datetime.timezone.utc).strftime(ISO_DATETIME_FORMAT)
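-
-# Illustrative doctest-style sketch (not part of the module) of the helpers above:
-#
-#   >>> camel_case("snake_case")
-#   'snakeCase'
-#   >>> snake_case("camelCase")
-#   'camel_case'
-#   >>> format_iso_datetime(datetime.datetime(2011, 11, 4, 0, 5, 23, tzinfo=datetime.timezone.utc))
-#   '2011-11-04T00:05:23Z'
-
-
-def format_iso_from_string(date_string: str) -> datetime.datetime:
- """
- Converts a string even if offset is missing: 2011-11-04T00:05:23Z or 2011-11-04T00:05:23+00:00 or 2011-11-04T00:05:23
- to a datetime object.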
- For missing offsets, the default offset is UTC. - """ - # return datetime.datetime.fromisoformat(date_string) - return default_tzinfo(dateutil_parse(date_string), DFLT_TZ) - ----- -labelbox/exceptions.py -class LabelboxError(Exception): - """Base class for exceptions.""" - - def __init__(self, message, cause=None): - """ - Args: - message (str): Informative message about the exception. - cause (Exception): The cause of the exception (an Exception - raised by Python or another library). Optional. - """ - super().__init__(message, cause) - self.message = message - self.cause = cause - - def __str__(self): - return self.message + str(self.args) - - -class AuthenticationError(LabelboxError): - """Raised when an API key fails authentication.""" - pass - - -class AuthorizationError(LabelboxError): - """Raised when a user is unauthorized to perform the given request.""" - pass - - -class ResourceNotFoundError(LabelboxError): - """Exception raised when a given resource is not found. """ - - def __init__(self, db_object_type, params): - """ Constructor. - - Args: - db_object_type (type): A labelbox.schema.DbObject subtype. - params (dict): Dict of params identifying the sought resource. - """ - super().__init__("Resource '%s' not found for params: %r" % - (db_object_type.type_name(), params)) - self.db_object_type = db_object_type - self.params = params - - -class ResourceConflict(LabelboxError): - """Exception raised when a given resource conflicts with another. """ - pass - - -class ValidationFailedError(LabelboxError): - """Exception raised for when a GraphQL query fails validation (query cost, - etc.) E.g. a query that is too expensive, or depth is too deep. - """ - pass - - -class InternalServerError(LabelboxError): - """Nondescript prisma or 502 related errors. - - Meant to be retryable. - - TODO: these errors need better messages from platform - """ - pass - - -class InvalidQueryError(LabelboxError): - """ Indicates a malconstructed or unsupported query (either by GraphQL in - general or by Labelbox specifically). This can be the result of either client - or server side query validation. """ - pass - - -class ResourceCreationError(LabelboxError): - """ Indicates that a resource could not be created in the server side - due to a validation or transaction error""" - pass - - -class NetworkError(LabelboxError): - """Raised when an HTTPError occurs.""" - - def __init__(self, cause): - super().__init__(str(cause), cause) - self.cause = cause - - -class TimeoutError(LabelboxError): - """Raised when a request times-out.""" - pass - - -class InvalidAttributeError(LabelboxError): - """ Raised when a field (name or Field instance) is not valid or found - for a specific DB object type. """ - - def __init__(self, db_object_type, field): - super().__init__("Field(s) '%r' not valid on DB type '%s'" % - (field, db_object_type.type_name())) - self.db_object_type = db_object_type - self.field = field - - -class ApiLimitError(LabelboxError): - """ Raised when the user performs too many requests in a short period - of time. 
""" - pass - - -class MalformedQueryException(Exception): - """ Raised when the user submits a malformed query.""" - pass - - -class UuidError(LabelboxError): - """ Raised when there are repeat Uuid's in bulk import request.""" - pass - - -class InconsistentOntologyException(Exception): - pass - - -class MALValidationError(LabelboxError): - """Raised when user input is invalid for MAL imports.""" - pass - - -class OperationNotAllowedException(Exception): - """Raised when user does not have permissions to a resource or has exceeded usage limit""" - pass - - -class ConfidenceNotSupportedException(Exception): - """Raised when confidence is specified for unsupported annotation type""" - - -class CustomMetricsNotSupportedException(Exception): - """Raised when custom_metrics is specified for unsupported annotation type""" - - -class ProcessingWaitTimeout(Exception): - """Raised when waiting for the data rows to be processed takes longer than allowed""" - ----- -labelbox/pydantic_compat.py -from typing import Optional - - -def pydantic_import(class_name, sub_module_path: Optional[str] = None): - import importlib - import pkg_resources - - # Get the version of pydantic - pydantic_version = pkg_resources.get_distribution("pydantic").version - - # Determine the module name based on the version - module_name = "pydantic" if pydantic_version.startswith( - "1") else "pydantic.v1" - module_name = f"{module_name}.{sub_module_path}" if sub_module_path else module_name - - # Import the class from the module - klass = getattr(importlib.import_module(module_name), class_name) - - return klass - - -BaseModel = pydantic_import("BaseModel") -PrivateAttr = pydantic_import("PrivateAttr") -Field = pydantic_import("Field") -ModelField = pydantic_import("ModelField", "fields") -ValidationError = pydantic_import("ValidationError") -ErrorWrapper = pydantic_import("ErrorWrapper", "error_wrappers") - -validator = pydantic_import("validator") -root_validator = pydantic_import("root_validator") -conint = pydantic_import("conint") -conlist = pydantic_import("conlist") -constr = pydantic_import("constr") -confloat = pydantic_import("confloat") - ----- -labelbox/pagination.py -# Size of a single page in a paginated query. -from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from labelbox import Client - from labelbox.orm.db_object import DbObject - -_PAGE_SIZE = 100 - - -class PaginatedCollection: - """ An iterable collection of database objects (Projects, Labels, etc...). - - Implements automatic (transparent to the user) paginated fetching during - iteration. Intended for use by library internals and not by the end user. - For a list of attributes see __init__(...) documentation. The params of - __init__ map exactly to object attributes. - """ - - def __init__(self, - client: "Client", - query: str, - params: Dict[str, Union[str, int]], - dereferencing: Union[List[str], Dict[str, Any]], - obj_class: Union[Type["DbObject"], Callable[[Any, Any], Any]], - cursor_path: Optional[List[str]] = None, - experimental: bool = False): - """ Creates a PaginatedCollection. - - Args: - client (labelbox.Client): the client used for fetching data from DB. - query (str): Base query used for pagination. It must contain two - '%d' placeholders, the first for pagination 'skip' clause and - the second for the 'first' clause. - params (dict): Query parameters. 
- dereferencing (iterable): An iterable of str defining the keypath - that needs to be dereferenced in the query result in order to - reach the paginated objects of interest. - obj_class (type): The class of object to be instantiated with each - dict containing db values. - cursor_path: If not None, this is used to find the cursor - experimental: Used to call experimental endpoints - """ - self._fetched_all = False - self._data: List[Dict[str, Any]] = [] - self._data_ind = 0 - - pagination_kwargs = { - 'client': client, - 'obj_class': obj_class, - 'dereferencing': dereferencing, - 'experimental': experimental, - 'query': query, - 'params': params - } - - self.paginator = _CursorPagination( - cursor_path, ** - pagination_kwargs) if cursor_path else _OffsetPagination( - **pagination_kwargs) - - def __iter__(self): - self._data_ind = 0 - return self - - def __next__(self): - if len(self._data) <= self._data_ind: - if self._fetched_all: - raise StopIteration() - - page_data, self._fetched_all = self.paginator.get_next_page() - self._data.extend(page_data) - if len(page_data) == 0: - raise StopIteration() - - rval = self._data[self._data_ind] - self._data_ind += 1 - return rval - - def get_one(self): - """Iterates over self and returns first value - This method is idempotent - """ - for value in self: - return value - - def get_many(self, n: int): - """Iterates over self and returns first n results - This method is idempotent - - Args: - n (int): Number of elements to retrieve - """ - results = [] - i = 0 - - for value in self: - if i >= n: - break - - results.append(value) - i += 1 - - return results - - -class _Pagination(ABC): - - def __init__(self, client: "Client", obj_class: Type["DbObject"], - dereferencing: Dict[str, Any], query: str, - params: Dict[str, Any], experimental: bool): - self.client = client - self.obj_class = obj_class - self.dereferencing = dereferencing - self.experimental = experimental - self.query = query - self.params = params - - def get_page_data(self, results: Dict[str, Any]) -> List["DbObject"]: - for deref in self.dereferencing: - results = results[deref] - - return [self.obj_class(self.client, result) for result in results] - - @abstractmethod - def get_next_page(self) -> Tuple[Dict[str, Any], bool]: - ... 
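-
-# Illustrative sketch (not part of the module): a PaginatedCollection is
-# consumed lazily, so only as many pages as needed are fetched. The method
-# names below are from this SDK; get_projects is assumed to be available on
-# the client.
-#
-#   projects = client.get_projects()   # returns a PaginatedCollection
-#   for project in projects.get_many(5):
-#       print(project.name)
-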
- - -class _CursorPagination(_Pagination): - - def __init__(self, cursor_path: List[str], *args, **kwargs): - super().__init__(*args, **kwargs) - self.cursor_path = cursor_path - self.next_cursor: Optional[Any] = kwargs.get('params', {}).get('from') - - def increment_page(self, results: Dict[str, Any]): - for path in self.cursor_path: - results = results[path] - self.next_cursor = results - - def fetched_all(self) -> bool: - return not self.next_cursor - - def fetch_results(self) -> Dict[str, Any]: - page_size = self.params.get('first', _PAGE_SIZE) - self.params.update({'from': self.next_cursor, 'first': page_size}) - return self.client.execute(self.query, - self.params, - experimental=self.experimental) - - def get_next_page(self): - results = self.fetch_results() - page_data = self.get_page_data(results) - self.increment_page(results) - done = self.fetched_all() - return page_data, done - - -class _OffsetPagination(_Pagination): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._fetched_pages = 0 - - def increment_page(self): - self._fetched_pages += 1 - - def fetched_all(self, n_items: int) -> bool: - return n_items < _PAGE_SIZE - - def fetch_results(self) -> Dict[str, Any]: - query = self.query % (self._fetched_pages * _PAGE_SIZE, _PAGE_SIZE) - return self.client.execute(query, - self.params, - experimental=self.experimental) - - def get_next_page(self): - results = self.fetch_results() - page_data = self.get_page_data(results) - self.increment_page() - done = self.fetched_all(len(page_data)) - return page_data, done - ----- -labelbox/schema/catalog.py -from typing import Any, Dict, List, Optional, Union -from labelbox.orm.db_object import experimental -from labelbox.schema.export_filters import CatalogExportFilters, build_filters - -from labelbox.schema.export_params import (CatalogExportParams, - validate_catalog_export_params) -from labelbox.schema.export_task import ExportTask -from labelbox.schema.task import Task - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from labelbox import Client - - -class Catalog: - client: "Client" - - def __init__(self, client: 'Client'): - self.client = client - - def export_v2( - self, - task_name: Optional[str] = None, - filters: Union[CatalogExportFilters, Dict[str, List[str]], None] = None, - params: Optional[CatalogExportParams] = None, - ) -> Task: - """ - Creates a catalog export task with the given params, filters and returns the task. - - >>> import labelbox as lb - >>> client = lb.Client() - >>> catalog = client.get_catalog() - >>> task = catalog.export_v2( - >>> filters={ - >>> "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - >>> "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - >>> }, - >>> params={ - >>> "performance_details": False, - >>> "label_details": True - >>> }) - >>> task.wait_till_done() - >>> task.result - """ - return self._export(task_name, filters, params, False) - - @experimental - def export( - self, - task_name: Optional[str] = None, - filters: Union[CatalogExportFilters, Dict[str, List[str]], None] = None, - params: Optional[CatalogExportParams] = None, - ) -> ExportTask: - """ - Creates a catalog export task with the given params, filters and returns the task. 
- - >>> import labelbox as lb - >>> client = lb.Client() - >>> export_task = Catalog.export( - >>> filters={ - >>> "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - >>> "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - >>> }, - >>> params={ - >>> "performance_details": False, - >>> "label_details": True - >>> }) - >>> export_task.wait_till_done() - >>> - >>> # Return a JSON output string from the export task results/errors one by one: - >>> def json_stream_handler(output: lb.JsonConverterOutput): - >>> print(output.json_str) - >>> - >>> if export_task.has_errors(): - >>> export_task.get_stream( - >>> converter=lb.JsonConverter(), - >>> stream_type=lb.StreamType.ERRORS - >>> ).start(stream_handler=lambda error: print(error)) - >>> - >>> if export_task.has_result(): - >>> export_json = export_task.get_stream( - >>> converter=lb.JsonConverter(), - >>> stream_type=lb.StreamType.RESULT - >>> ).start(stream_handler=json_stream_handler) - """ - task = self._export(task_name, filters, params, True) - return ExportTask(task) - - def _export(self, - task_name: Optional[str] = None, - filters: Union[CatalogExportFilters, Dict[str, List[str]], - None] = None, - params: Optional[CatalogExportParams] = None, - streamable: bool = False) -> Task: - - _params = params or CatalogExportParams({ - "attachments": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "model_run_ids": None, - "project_ids": None, - "interpolated_frames": False, - "all_projects": False, - "all_model_runs": False, - }) - validate_catalog_export_params(_params) - - _filters = filters or CatalogExportFilters({ - "last_activity_at": None, - "label_created_at": None, - "data_row_ids": None, - "global_keys": None, - }) - - mutation_name = "exportDataRowsInCatalog" - create_task_query_str = ( - f"mutation {mutation_name}PyApi" - f"($input: ExportDataRowsInCatalogInput!)" - f"{{{mutation_name}(input: $input){{taskId}}}}") - - media_type_override = _params.get('media_type_override', None) - query_params: Dict[str, Any] = { - "input": { - "taskName": task_name, - "filters": { - "searchQuery": { - "scope": None, - "query": None, - } - }, - "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), - "projectIds": - _params.get('project_ids', None), - "modelRunIds": - _params.get('model_run_ids', None), - "allProjects": - _params.get('all_projects', False), - "allModelRuns": - _params.get('all_model_runs', False), - }, - "streamable": streamable, - } - } - - search_query = build_filters(self.client, _filters) - query_params["input"]["filters"]["searchQuery"]["query"] = search_query - - res = self.client.execute(create_task_query_str, - query_params, - error_log_key="errors") - res = res[mutation_name] - task_id = res["taskId"] - return Task.get_task(self.client, task_id) - ----- -labelbox/schema/model_run.py -# type: ignore -import logging -import os -import 
time
-import warnings
-from enum import Enum
-from pathlib import Path
-from typing import TYPE_CHECKING, Dict, Iterable, Union, List, Optional, Any
-
-import requests
-
-from labelbox import parser
-from labelbox.orm.db_object import DbObject, experimental
-from labelbox.orm.model import Field, Relationship, Entity
-from labelbox.orm.query import results_query_part
-from labelbox.pagination import PaginatedCollection
-from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy
-from labelbox.schema.export_params import ModelRunExportParams
-from labelbox.schema.export_task import ExportTask
-from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds
-from labelbox.schema.send_to_annotate_params import SendToAnnotateFromModelParams, build_destination_task_queue_input, \
- build_predictions_input
-from labelbox.schema.task import Task
-
-if TYPE_CHECKING:
- from labelbox import MEAPredictionImport
- from labelbox.types import Label
-
-logger = logging.getLogger(__name__)
-
-DATAROWS_IMPORT_LIMIT = 25000
-
-
-class DataSplit(Enum):
- TRAINING = "TRAINING"
- TEST = "TEST"
- VALIDATION = "VALIDATION"
- UNASSIGNED = "UNASSIGNED"
-
-
-class ModelRun(DbObject):
- name = Field.String("name")
- updated_at = Field.DateTime("updated_at")
- created_at = Field.DateTime("created_at")
- created_by_id = Field.String("created_by_id", "createdBy")
- model_id = Field.String("model_id")
- training_metadata = Field.Json("training_metadata")
-
- class Status(Enum):
- EXPORTING_DATA = "EXPORTING_DATA"
- PREPARING_DATA = "PREPARING_DATA"
- TRAINING_MODEL = "TRAINING_MODEL"
- COMPLETE = "COMPLETE"
- FAILED = "FAILED"
-
- def upsert_labels(self,
- label_ids: Optional[List[str]] = None,
- project_id: Optional[str] = None,
- timeout_seconds=3600):
- """
- Adds data rows and labels to a Model Run
-
- Args:
- label_ids (list): label ids to insert
- project_id (string): project uuid, all project labels will be uploaded
- Either label_ids OR project_id is required but NOT both
- timeout_seconds (float): Max waiting time, in seconds.
- Returns:
- ID of newly generated async task
-
- """
-
- use_label_ids = label_ids is not None and len(label_ids) > 0
- use_project_id = project_id is not None
-
- if not use_label_ids and not use_project_id:
- raise ValueError(
- "Must provide at least one label id or a project id")
-
- if use_label_ids and use_project_id:
- raise ValueError(
- "Must provide only one of label ids or project id")
-
- if use_label_ids:
- return self._upsert_labels_by_label_ids(label_ids, timeout_seconds)
- else: # use_project_id
- return self._upsert_labels_by_project_id(project_id,
- timeout_seconds)
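-
- # Illustrative sketch (not part of the SDK): registering all labels from a
- # project onto a model run. The id below is a placeholder.
- #
- #   model_run.upsert_labels(project_id="<project_id>")
-
- def _upsert_labels_by_label_ids(self, label_ids: List[str],
- timeout_seconds: int):
- mutation_name = 'createMEAModelRunLabelRegistrationTask'
- create_task_query_str = """mutation createMEAModelRunLabelRegistrationTaskPyApi($modelRunId: ID!, $labelIds : [ID!]!) 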
{ - %s(where : { id : $modelRunId}, data : {labelIds: $labelIds})} - """ % (mutation_name) - - res = self.client.execute(create_task_query_str, { - 'modelRunId': self.uid, - 'labelIds': label_ids - }) - task_id = res[mutation_name] - - status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){ - MEALabelRegistrationTaskStatus(where: $where) {status errorMessage} - } - """ - return self._wait_until_done(lambda: self.client.execute( - status_query_str, {'where': { - 'id': task_id - }})['MEALabelRegistrationTaskStatus'], - timeout_seconds=timeout_seconds) - - def _upsert_labels_by_project_id(self, project_id: str, - timeout_seconds: int): - mutation_name = 'createMEAModelRunProjectLabelRegistrationTask' - create_task_query_str = """mutation createMEAModelRunProjectLabelRegistrationTaskPyApi($modelRunId: ID!, $projectId : ID!) { - %s(where : { modelRunId : $modelRunId, projectId: $projectId})} - """ % (mutation_name) - - res = self.client.execute(create_task_query_str, { - 'modelRunId': self.uid, - 'projectId': project_id - }) - task_id = res[mutation_name] - - status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){ - MEALabelRegistrationTaskStatus(where: $where) {status errorMessage} - } - """ - return self._wait_until_done(lambda: self.client.execute( - status_query_str, {'where': { - 'id': task_id - }})['MEALabelRegistrationTaskStatus'], - timeout_seconds=timeout_seconds) - - def upsert_data_rows(self, - data_row_ids=None, - global_keys=None, - timeout_seconds=3600): - """ Adds data rows to a Model Run without any associated labels - Args: - data_row_ids (list): data row ids to add to model run - global_keys (list): global keys for data rows to add to model run - timeout_seconds (float): Max waiting time, in seconds. - Returns: - ID of newly generated async task - """ - - mutation_name = 'createMEAModelRunDataRowRegistrationTask' - create_task_query_str = """mutation createMEAModelRunDataRowRegistrationTaskPyApi($modelRunId: ID!, $dataRowIds: [ID!], $globalKeys: [ID!]) { - %s(where : { id : $modelRunId}, data : {dataRowIds: $dataRowIds, globalKeys: $globalKeys})} - """ % (mutation_name) - - res = self.client.execute( - create_task_query_str, { - 'modelRunId': self.uid, - 'dataRowIds': data_row_ids, - 'globalKeys': global_keys - }) - task_id = res[mutation_name] - - status_query_str = """query MEADataRowRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){ - MEADataRowRegistrationTaskStatus(where: $where) {status errorMessage} - } - """ - return self._wait_until_done(lambda: self.client.execute( - status_query_str, {'where': { - 'id': task_id - }})['MEADataRowRegistrationTaskStatus'], - timeout_seconds=timeout_seconds) - - def _wait_until_done(self, status_fn, timeout_seconds=120, sleep_time=5): - # Do not use this function outside of the scope of upsert_data_rows or upsert_labels. It could change. - original_timeout = timeout_seconds - while True: - res = status_fn() - if res['status'] == 'COMPLETE': - return True - elif res['status'] == 'FAILED': - raise Exception(f"Job failed.") - timeout_seconds -= sleep_time - if timeout_seconds <= 0: - raise TimeoutError( - f"Unable to complete import within {original_timeout} seconds." 
- )
- time.sleep(sleep_time)
-
- def upsert_predictions_and_send_to_project(
- self,
- name: str,
- predictions: Union[str, Path, Iterable[Dict]],
- project_id: str,
- priority: Optional[int] = 5,
- ) -> 'MEAPredictionImport': # type: ignore
- """
- Provides a convenient way to execute the following steps in a single function call:
- 1. Upload predictions to a Model
- 2. Create a batch from data rows that had predictions associated with them
- 3. Attach the batch to a project
- 4. Add those same predictions to the project as MAL annotations
-
- Note that partial successes are possible.
- If it is important that all stages are successful then check the status of each individual task
- with task.errors. E.g.
-
- >>> mea_import_job, batch, mal_import_job = upsert_predictions_and_send_to_project(name, predictions, project_id)
- >>> # handle mea import job successfully created (check for job failure or partial failures)
- >>> print(mea_import_job.status, mea_import_job.errors)
- >>> if batch is None:
- >>> # Handle batch creation failure
- >>> if mal_import_job is None:
- >>> # Handle mal_import_job creation failure
- >>> else:
- >>> # handle mal import job successfully created (check for job failure or partial failures)
- >>> print(mal_import_job.status, mal_import_job.errors)
-
-
- Args:
- name (str): name of the AnnotationImport job as well as the name of the batch import
- predictions (Iterable):
- iterable of annotation rows
- project_id (str): id of the project to import into
- priority (int): priority of the job
- Returns:
- Tuple[MEAPredictionImport, Batch, MEAToMALPredictionImport]
- If any of these steps fail the return value will be None.
-
- """
- kwargs = dict(client=self.client, model_run_id=self.uid, name=name)
- project = self.client.get_project(project_id)
- import_job = self.add_predictions(name, predictions)
- prediction_statuses = import_job.statuses
- mea_to_mal_data_rows = list(
- set([
- row['dataRow']['id']
- for row in prediction_statuses
- if row['status'] == 'SUCCESS'
- ]))
-
- if not mea_to_mal_data_rows:
- # 0 successful model predictions imported
- return import_job, None, None
-
- elif len(mea_to_mal_data_rows) >= DATAROWS_IMPORT_LIMIT:
- # log the original count before trimming so the warning is accurate
- logger.warning(
- f"Exceeded max data row limit {len(mea_to_mal_data_rows)}, trimmed down to {DATAROWS_IMPORT_LIMIT} data rows."
- )
- mea_to_mal_data_rows = mea_to_mal_data_rows[:DATAROWS_IMPORT_LIMIT]
-
- try:
- batch = project.create_batch(name, mea_to_mal_data_rows, priority)
- except Exception as e:
- logger.warning(f"Failed to create batch. Message : {e}.")
- # Unable to create batch
- return import_job, None, None
-
- try:
- mal_prediction_import = Entity.MEAToMALPredictionImport.create_for_model_run_data_rows(
- data_row_ids=mea_to_mal_data_rows,
- project_id=project_id,
- **kwargs)
- mal_prediction_import.wait_until_done()
- except Exception as e:
- logger.warning(
- f"Failed to create MEA to MAL prediction import. Message : {e}."
- )
- # Unable to create mea to mal prediction import
- return import_job, batch, None
-
- return import_job, batch, mal_prediction_import
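-
- # Illustrative sketch (not part of the SDK): uploading predictions from a
- # local ndjson file. The path below is a placeholder.
- #
- #   import_job = model_run.add_predictions(
- #       name="import-1", predictions="/path/to/predictions.ndjson")
- #   import_job.wait_until_done()
-
- def add_predictions(
- self,
- name: str,
- predictions: Union[str, Path, Iterable[Dict], Iterable["Label"]],
- ) -> 'MEAPredictionImport': # type: ignore
- """
- Uploads predictions to a new Editor project.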
- - Args: - name (str): name of the AnnotationImport job - predictions (str or Path or Iterable): url that is publicly accessible by Labelbox containing an - ndjson file - OR local path to an ndjson file - OR iterable of annotation rows - - Returns: - AnnotationImport - """ - kwargs = dict(client=self.client, model_run_id=self.uid, name=name) - if isinstance(predictions, str) or isinstance(predictions, Path): - if os.path.exists(predictions): - return Entity.MEAPredictionImport.create_from_file( - path=str(predictions), **kwargs) - else: - return Entity.MEAPredictionImport.create_from_url( - url=str(predictions), **kwargs) - elif isinstance(predictions, Iterable): - return Entity.MEAPredictionImport.create_from_objects( - predictions=predictions, **kwargs) - else: - raise ValueError( - f'Invalid predictions given of type: {type(predictions)}') - - def model_run_data_rows(self): - query_str = """query modelRunPyApi($modelRunId: ID!, $from : String, $first: Int){ - annotationGroups(where: {modelRunId: {id: $modelRunId}}, after: $from, first: $first) - {nodes{%s},pageInfo{endCursor}} - } - """ % (results_query_part(ModelRunDataRow)) - return PaginatedCollection( - self.client, query_str, {'modelRunId': self.uid}, - ['annotationGroups', 'nodes'], - lambda client, res: ModelRunDataRow(client, self.model_id, res), - ['annotationGroups', 'pageInfo', 'endCursor']) - - def delete(self): - """ Deletes specified Model Run. - - Returns: - Query execution success. - """ - ids_param = "ids" - query_str = """mutation DeleteModelRunPyApi($%s: ID!) { - deleteModelRuns(where: {ids: [$%s]})}""" % (ids_param, ids_param) - self.client.execute(query_str, {ids_param: str(self.uid)}) - - def delete_model_run_data_rows(self, data_row_ids: List[str]): - """ Deletes data rows from Model Runs. - - Args: - data_row_ids (list): List of data row ids to delete from the Model Run. - Returns: - Query execution success. - """ - model_run_id_param = "modelRunId" - data_row_ids_param = "dataRowIds" - query_str = """mutation DeleteModelRunDataRowsPyApi($%s: ID!, $%s: [ID!]!) { - deleteModelRunDataRows(where: {modelRunId: $%s, dataRowIds: $%s})}""" % ( - model_run_id_param, data_row_ids_param, model_run_id_param, - data_row_ids_param) - self.client.execute(query_str, { - model_run_id_param: self.uid, - data_row_ids_param: data_row_ids - }) - - @experimental - def assign_data_rows_to_split(self, - data_row_ids: List[str] = None, - split: Union[DataSplit, str] = None, - global_keys: List[str] = None, - timeout_seconds=120): - - split_value = split.value if isinstance(split, DataSplit) else split - valid_splits = DataSplit._member_names_ - - if split_value is None or split_value not in valid_splits: - raise ValueError( - f"`split` must be one of : `{valid_splits}`. 
Found : `{split}`") - - task_id = self.client.execute( - """mutation assignDataSplitPyApi($modelRunId: ID!, $data: CreateAssignDataRowsToDataSplitTaskInput!){ - createAssignDataRowsToDataSplitTask(modelRun : {id: $modelRunId}, data: $data)} - """, { - 'modelRunId': self.uid, - 'data': { - 'assignments': [{ - 'split': split_value, - 'dataRowIds': data_row_ids, - 'globalKeys': global_keys, - }] - } - }, - experimental=True)['createAssignDataRowsToDataSplitTask'] - - status_query_str = """query assignDataRowsToDataSplitTaskStatusPyApi($id: ID!){ - assignDataRowsToDataSplitTaskStatus(where: {id : $id}){status errorMessage}} - """ - - return self._wait_until_done(lambda: self.client.execute( - status_query_str, {'id': task_id}, experimental=True)[ - 'assignDataRowsToDataSplitTaskStatus'], - timeout_seconds=timeout_seconds) - - @experimental - def update_status(self, - status: Union[str, "ModelRun.Status"], - metadata: Optional[Dict[str, str]] = None, - error_message: Optional[str] = None): - - status_value = status.value if isinstance(status, - ModelRun.Status) else status - if status_value not in ModelRun.Status._member_names_: - raise ValueError( - f"Status must be one of : `{ModelRun.Status._member_names_}`. Found : `{status_value}`" - ) - - data: Dict[str, Any] = {'status': status_value} - if error_message: - data['errorMessage'] = error_message - - if metadata: - data['metadata'] = metadata - - self.client.execute( - """mutation setPipelineStatusPyApi($modelRunId: ID!, $data: UpdateTrainingPipelineInput!){ - updateTrainingPipeline(modelRun: {id : $modelRunId}, data: $data){status} - } - """, { - 'modelRunId': self.uid, - 'data': data - }, - experimental=True) - - @experimental - def update_config(self, config: Dict[str, Any]) -> Dict[str, Any]: - """ - Updates the Model Run's training metadata config - Args: - config (dict): A dictionary of keys and values - Returns: - Model Run id and updated training metadata - """ - data: Dict[str, Any] = {'config': config} - res = self.client.execute( - """mutation updateModelRunConfigPyApi($modelRunId: ID!, $data: UpdateModelRunConfigInput!){ - updateModelRunConfig(modelRun: {id : $modelRunId}, data: $data){trainingMetadata} - } - """, { - 'modelRunId': self.uid, - 'data': data - }, - experimental=True) - return res["updateModelRunConfig"] - - @experimental - def reset_config(self) -> Dict[str, Any]: - """ - Resets Model Run's training metadata config - Returns: - Model Run id and reset training metadata - """ - res = self.client.execute( - """mutation resetModelRunConfigPyApi($modelRunId: ID!){ - resetModelRunConfig(modelRun: {id : $modelRunId}){trainingMetadata} - } - """, {'modelRunId': self.uid}, - experimental=True) - return res["resetModelRunConfig"] - - @experimental - def get_config(self) -> Dict[str, Any]: - """ - Gets Model Run's training metadata - Returns: - training metadata as a dictionary - """ - res = self.client.execute("""query ModelRunPyApi($modelRunId: ID!){ - modelRun(where: {id : $modelRunId}){trainingMetadata} - } - """, {'modelRunId': self.uid}, - experimental=True) - return res["modelRun"]["trainingMetadata"] - - @experimental - def export_labels( - self, - download: bool = False, - timeout_seconds: int = 600 - ) -> Optional[Union[str, List[Dict[Any, Any]]]]: - """ - Experimental. To use, make sure client has enable_experimental=True. - - Fetches Labels from the ModelRun - - Args: - download (bool): Returns the url if False - Returns: - URL of the data file with this ModelRun's labels. 
- If download=True, this instead returns the contents as NDJSON format. - If the server didn't generate during the `timeout_seconds` period, - None is returned. - """ - warnings.warn( - "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) - sleep_time = 2 - query_str = """mutation exportModelRunAnnotationsPyApi($modelRunId: ID!) { - exportModelRunAnnotations(data: {modelRunId: $modelRunId}) { - downloadUrl createdAt status - } - } - """ - - while True: - url = self.client.execute( - query_str, {'modelRunId': self.uid}, - experimental=True)['exportModelRunAnnotations']['downloadUrl'] - - if url: - if not download: - return url - else: - response = requests.get(url) - response.raise_for_status() - return parser.loads(response.content) - - timeout_seconds -= sleep_time - if timeout_seconds <= 0: - return None - - logger.debug("ModelRun '%s' label export, waiting for server...", - self.uid) - time.sleep(sleep_time) - - @experimental - def export(self, - task_name: Optional[str] = None, - params: Optional[ModelRunExportParams] = None) -> ExportTask: - """ - Creates a model run export task with the given params and returns the task. - - >>> export_task = export("my_export_task", params={"media_attributes": True}) - - """ - task = self._export(task_name, params, streamable=True) - return ExportTask(task) - - def export_v2( - self, - task_name: Optional[str] = None, - params: Optional[ModelRunExportParams] = None, - ) -> Task: - """ - Creates a model run export task with the given params and returns the task. - - >>> export_task = export_v2("my_export_task", params={"media_attributes": True}) - - """ - return self._export(task_name, params) - - def _export( - self, - task_name: Optional[str] = None, - params: Optional[ModelRunExportParams] = None, - streamable: bool = False, - ) -> Task: - mutation_name = "exportDataRowsInModelRun" - create_task_query_str = ( - f"mutation {mutation_name}PyApi" - f"($input: ExportDataRowsInModelRunInput!)" - f"{{{mutation_name}(input: $input){{taskId}}}}") - - _params = params or ModelRunExportParams() - - query_params = { - "input": { - "taskName": task_name, - "filters": { - "modelRunId": self.uid - }, - "params": { - "mediaTypeOverride": - _params.get('media_type_override', None), - "includeAttachments": - _params.get('attachments', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includePredictions": - _params.get('predictions', False), - "includeModelRunDetails": - _params.get('model_run_details', False), - }, - "streamable": streamable - } - } - res = self.client.execute(create_task_query_str, - query_params, - error_log_key="errors") - res = res[mutation_name] - task_id = res["taskId"] - return Task.get_task(self.client, task_id) - - def send_to_annotate_from_model( - self, destination_project_id: str, task_queue_id: Optional[str], - batch_name: str, data_rows: Union[DataRowIds, GlobalKeys], - params: SendToAnnotateFromModelParams) -> Task: - """ - Sends data rows from a model run to a project for annotation. 
- - Example Usage: - >>> task = model_run.send_to_annotate_from_model( - >>> destination_project_id=DESTINATION_PROJECT_ID, - >>> batch_name="batch", - >>> data_rows=UniqueIds([DATA_ROW_ID]), - >>> task_queue_id=TASK_QUEUE_ID, - >>> params={}) - >>> task.wait_till_done() - - Args: - destination_project_id: The ID of the project to send the data rows to. - task_queue_id: The ID of the task queue to send the data rows to. If not specified, the data rows will be - sent to the Done workflow state. - batch_name: The name of the batch to create. If more than one batch is created, additional batches will be - named with a monotonically increasing numerical suffix, starting at "_1". - data_rows: The data rows to send to the project. - params: Additional parameters for this operation. See SendToAnnotateFromModelParams for details. - - Returns: The created task for this operation. - - """ - - mutation_str = """mutation SendToAnnotateFromMeaPyApi($input: SendToAnnotateFromMeaInput!) { - sendToAnnotateFromMea(input: $input) { - taskId - } - } - """ - - destination_task_queue = build_destination_task_queue_input( - task_queue_id) - data_rows_query = self.client.build_catalog_query(data_rows) - - predictions_ontology_mapping = params.get( - "predictions_ontology_mapping", None) - predictions_input = build_predictions_input( - predictions_ontology_mapping, self.uid) - - batch_priority = params.get("batch_priority", 5) - exclude_data_rows_in_project = params.get( - "exclude_data_rows_in_project", False) - override_existing_annotations_rule = params.get( - "override_existing_annotations_rule", - ConflictResolutionStrategy.KeepExisting) - res = self.client.execute( - mutation_str, { - "input": { - "destinationProjectId": - destination_project_id, - "batchInput": { - "batchName": batch_name, - "batchPriority": batch_priority - }, - "destinationTaskQueue": - destination_task_queue, - "excludeDataRowsInProject": - exclude_data_rows_in_project, - "annotationsInput": - None, - "predictionsInput": - predictions_input, - "conflictLabelsResolutionStrategy": - override_existing_annotations_rule, - "searchQuery": [data_rows_query], - "sourceModelRunId": - self.uid - } - })['sendToAnnotateFromMea'] - - return Entity.Task.get_task(self.client, res['taskId']) - - -class ModelRunDataRow(DbObject): - label_id = Field.String("label_id") - model_run_id = Field.String("model_run_id") - data_split = Field.Enum(DataSplit, "data_split") - data_row = Relationship.ToOne("DataRow", False, cache=True) - - def __init__(self, client, model_id, *args, **kwargs): - super().__init__(client, *args, **kwargs) - self.model_id = model_id - - @property - def url(self): - app_url = self.client.app_url - endpoint = f"{app_url}/models/{self.model_id}/{self.model_run_id}/AllDatarowsSlice/{self.uid}?view=carousel" - return endpoint - ----- -labelbox/schema/identifiables.py -from typing import List, Union - -from labelbox.schema.id_type import IdType - - -class Identifiables: - - def __init__(self, iterable, id_type: IdType): - """ - Args: - iterable: Iterable of ids (unique or global keys) - id_type: The type of id used to identify a data row. 
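- 
- Example (illustrative; the values below are placeholders):
- >>> UniqueIds(["<data_row_id>"])
- >>> GlobalKeys(["<global_key>"])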
- """ - self._iterable = iterable - self._id_type = id_type - - @property - def id_type(self): - return self._id_type - - def __iter__(self): - return iter(self._iterable) - - def __getitem__(self, index): - if isinstance(index, slice): - ids = self._iterable[index] - return self.__class__(ids) # type: ignore - return self._iterable[index] - - def __len__(self): - return len(self._iterable) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self._iterable})" - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Identifiables): - return False - return self._iterable == other._iterable and self._id_type == other._id_type - - -class UniqueIds(Identifiables): - """ - Represents a collection of unique, internally generated ids. - """ - - def __init__(self, iterable: List[str]): - super().__init__(iterable, IdType.DataRowId) - - -class GlobalKeys(Identifiables): - """ - Represents a collection of user generated ids. - """ - - def __init__(self, iterable: List[str]): - super().__init__(iterable, IdType.GlobalKey) - - -DataRowIds = UniqueIds - -DataRowIdentifiers = Union[UniqueIds, GlobalKeys] - ----- -labelbox/schema/user.py -from typing import TYPE_CHECKING -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field, Relationship - -if TYPE_CHECKING: - from labelbox import Role, Project - - -class User(DbObject): - """ A User is a registered Labelbox user (for example you) associated with - data they create or import and an Organization they belong to. - - Attributes: - updated_at (datetime) - created_at (datetime) - email (str) - name (str) - nickname (str) - intercom_hash (str) - picture (str) - is_viewer (bool) - is_external_viewer (bool) - - organization (Relationship): `ToOne` relationship to Organization - created_tasks (Relationship): `ToMany` relationship to Task - projects (Relationship): `ToMany` relationship to Project - """ - - updated_at = Field.DateTime("updated_at") - created_at = Field.DateTime("created_at") - email = Field.String("email") - name = Field.String("nickname") - nickname = Field.String("name") - intercom_hash = Field.String("intercom_hash") - picture = Field.String("picture") - is_viewer = Field.Boolean("is_viewer") - is_external_user = Field.Boolean("is_external_user") - - # Relationships - organization = Relationship.ToOne("Organization") - created_tasks = Relationship.ToMany("Task", False, "created_tasks") - projects = Relationship.ToMany("Project", False) - org_role = Relationship.ToOne("OrgRole", False) - - def update_org_role(self, role: "Role") -> None: - """ Updated the `User`s organization role. - - See client.get_roles() to get all valid roles - If you a user is converted from project level permissions to org level permissions and then convert back, their permissions will remain for each individual project - - Args: - role (Role): The role that you want to set for this user. - - """ - user_id_param = "userId" - role_id_param = "roleId" - query_str = """mutation SetOrganizationRolePyApi($%s: ID!, $%s: ID!) { - setOrganizationRole(data: {userId: $userId, roleId: $roleId}) { id name }} - """ % (user_id_param, role_id_param) - - self.client.execute(query_str, { - user_id_param: self.uid, - role_id_param: role.uid - }) - - def remove_from_project(self, project: "Project") -> None: - """ Removes a User from a project. Only used for project based users. 
- Project based user means their org role is "NONE" - - Args: - project (Project): Project to remove user from - - """ - self.upsert_project_role(project, self.client.get_roles()['NONE']) - - def upsert_project_role(self, project: "Project", role: "Role") -> None: - """ Updates or replaces a User's role in a project. - - Args: - project (Project): The project to update the users permissions for - role (Role): The role to assign to this user in this project. - - """ - org_role = self.org_role() - if org_role.name.upper() != 'NONE': - raise ValueError( - "User is not project based and has access to all projects") - - project_id_param = "projectId" - user_id_param = "userId" - role_id_param = "roleId" - query_str = """mutation SetProjectMembershipPyApi($%s: ID!, $%s: ID!, $%s: ID!) { - setProjectMembership(data: {%s: $userId, roleId: $%s, projectId: $%s}) {id}} - """ % (user_id_param, role_id_param, project_id_param, user_id_param, - role_id_param, project_id_param) - - self.client.execute( - query_str, { - project_id_param: project.uid, - user_id_param: self.uid, - role_id_param: role.uid - }) - ----- -labelbox/schema/task.py -import json -import logging -import requests -import time -from typing import TYPE_CHECKING, Callable, Optional, Dict, Any, List, Union -from labelbox import parser - -from labelbox.exceptions import ResourceNotFoundError -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field, Relationship, Entity - -if TYPE_CHECKING: - from labelbox import User - - def lru_cache() -> Callable[..., Callable[..., Dict[str, Any]]]: - pass -else: - from functools import lru_cache - -logger = logging.getLogger(__name__) - - -class Task(DbObject): - """ Represents a server-side process that might take a longer time to process. - Allows the Task state to be updated and checked on the client side. - - Attributes: - updated_at (datetime) - created_at (datetime) - name (str) - status (str) - completion_percentage (float) - - created_by (Relationship): `ToOne` relationship to User - organization (Relationship): `ToOne` relationship to Organization - """ - updated_at = Field.DateTime("updated_at") - created_at = Field.DateTime("created_at") - name = Field.String("name") - status = Field.String("status") - completion_percentage = Field.Float("completion_percentage") - result_url = Field.String("result_url", "result") - errors_url = Field.String("errors_url", "errors") - type = Field.String("type") - metadata = Field.Json("metadata") - _user: Optional["User"] = None - - # Relationships - created_by = Relationship.ToOne("User", False, "created_by") - organization = Relationship.ToOne("Organization") - - def refresh(self) -> None: - """ Refreshes Task data from the server. """ - assert self._user is not None - tasks = list(self._user.created_tasks(where=Task.uid == self.uid)) - if len(tasks) != 1: - raise ResourceNotFoundError(Task, self.uid) - for field in self.fields(): - setattr(self, field.name, getattr(tasks[0], field.name)) - - def wait_till_done(self, - timeout_seconds: float = 300.0, - check_frequency: float = 2.0) -> None: - """ Waits until the task is completed. Periodically queries the server - to update the task attributes. - - Args: - timeout_seconds (float): Maximum time this method can block, in seconds. Defaults to five minutes. - check_frequency (float): Frequency of queries to server to update the task attributes, in seconds. Defaults to two seconds. Minimal value is two seconds. 
- """ - if check_frequency < 2.0: - raise ValueError( - "Expected check frequency to be two seconds or more") - while timeout_seconds > 0: - if self.status != "IN_PROGRESS": - # self.errors fetches the error content. - # This first condition prevents us from downloading the content for v2 exports - if self.errors_url is not None or self.errors is not None: - logger.warning( - "There are errors present. Please look at `task.errors` for more details" - ) - return - logger.debug("Task.wait_till_done sleeping for %d seconds", - check_frequency) - time.sleep(check_frequency) - timeout_seconds -= check_frequency - self.refresh() - - @property - def errors(self) -> Optional[Dict[str, Any]]: - """ Fetch the error associated with an import task. - """ - if self.name == 'JSON Import': - if self.status == "FAILED": - result = self._fetch_remote_json() - return result["error"] - elif self.status == "COMPLETE": - return self.failed_data_rows - elif self.type == "export-data-rows": - return self._fetch_remote_json(remote_json_field='errors_url') - elif (self.type == "add-data-rows-to-batch" or - self.type == "send-to-task-queue" or - self.type == "send-to-annotate"): - if self.status == "FAILED": - # for these tasks, the error is embedded in the result itself - return json.loads(self.result_url) - return None - - @property - def result(self) -> Union[List[Dict[str, Any]], Dict[str, Any]]: - """ Fetch the result for an import task. - """ - if self.status == "FAILED": - raise ValueError(f"Job failed. Errors : {self.errors}") - else: - result = self._fetch_remote_json() - if self.type == 'export-data-rows': - return result - - return [{ - 'id': data_row['id'], - 'external_id': data_row.get('externalId'), - 'row_data': data_row['rowData'], - 'global_key': data_row.get('globalKey'), - } for data_row in result['createdDataRows']] - - @property - def failed_data_rows(self) -> Optional[Dict[str, Any]]: - """ Fetch data rows which failed to be created for an import task. - """ - result = self._fetch_remote_json() - if len(result.get("errors", [])) > 0: - return result["errors"] - else: - return None - - @property - def created_data_rows(self) -> Optional[Dict[str, Any]]: - """ Fetch data rows which successfully created for an import task. - """ - result = self._fetch_remote_json() - if len(result.get("createdDataRows", [])) > 0: - return result["createdDataRows"] - else: - return None - - @lru_cache() - def _fetch_remote_json(self, - remote_json_field: Optional[str] = None - ) -> Dict[str, Any]: - """ Function for fetching and caching the result data. - """ - - def download_result(remote_json_field: Optional[str], format: str): - url = getattr(self, remote_json_field or 'result_url') - - if url is None: - return None - - response = requests.get(url) - response.raise_for_status() - if format == 'json': - return response.json() - elif format == 'ndjson': - return parser.loads(response.text) - else: - raise ValueError( - "Expected the result format to be either `ndjson` or `json`." - ) - - if self.name == 'JSON Import': - format = 'json' - elif self.type == 'export-data-rows': - format = 'ndjson' - else: - raise ValueError( - "Task result is only supported for `JSON Import` and `export` tasks." - " Download task.result_url manually to access the result for other tasks." - ) - - if self.status != "IN_PROGRESS": - return download_result(remote_json_field, format) - else: - self.wait_till_done(timeout_seconds=600) - if self.status == "IN_PROGRESS": - raise ValueError( - "Job status still in `IN_PROGRESS`. 
The result is not available. Call task.wait_till_done() with a larger timeout or contact support."
-                )
-                return download_result(remote_json_field, format)
-
-    @staticmethod
-    def get_task(client, task_id):
-        user: User = client.get_user()
-        tasks: List[Task] = list(
-            user.created_tasks(where=Entity.Task.uid == task_id))
-        # Cache the user in a private variable; the relationship can't be
-        # resolved client-side due to server-side limitations
-        # (see Task.created_by for more info).
-        if len(tasks) != 1:
-            raise ResourceNotFoundError(Entity.Task, {"task_id": task_id})
-        task: Task = tasks[0]
-        task._user = user
-        return task
-
-----
-labelbox/schema/enums.py
-from enum import Enum
-
-
-class BulkImportRequestState(Enum):
-    """ State of the import job when importing annotations (RUNNING, FAILED, or FINISHED).
-
-    If you are not using MEA, continue using BulkImportRequest.
-    AnnotationImports are in beta and will change soon.
-
-    .. list-table::
-       :widths: 15 150
-       :header-rows: 1
-
-       * - State
-         - Description
-       * - RUNNING
-         - Indicates that the import job is not done yet.
-       * - FAILED
-         - Indicates the import job failed. Check `BulkImportRequest.errors` for more information
-       * - FINISHED
-         - Indicates the import job is no longer running. Check `BulkImportRequest.statuses` for more information
-    """
-    RUNNING = "RUNNING"
-    FAILED = "FAILED"
-    FINISHED = "FINISHED"
-
-
-class AnnotationImportState(Enum):
-    """ State of the import job when importing annotations (RUNNING, FAILED, or FINISHED).
-
-    .. list-table::
-       :widths: 15 150
-       :header-rows: 1
-
-       * - State
-         - Description
-       * - RUNNING
-         - Indicates that the import job is not done yet.
-       * - FAILED
-         - Indicates the import job failed. Check `AnnotationImport.errors` for more information
-       * - FINISHED
-         - Indicates the import job is no longer running. Check `AnnotationImport.statuses` for more information
-    """
-    RUNNING = "RUNNING"
-    FAILED = "FAILED"
-    FINISHED = "FINISHED"
-
-
-class CollectionJobStatus(Enum):
-    """ Status of an asynchronous job over a collection.
-
-    .. list-table::
-       :widths: 15 150
-       :header-rows: 1
-
-       * - State
-         - Description
-       * - SUCCESS
-         - Indicates the job successfully processed the entire collection of data
-       * - PARTIAL SUCCESS
-         - Indicates some data in the collection succeeded and other data failed
-       * - FAILURE
-         - Indicates the job failed to process the entire collection of data
-    """
-    SUCCESS = "SUCCESS"
-    PARTIAL_SUCCESS = "PARTIAL SUCCESS"
-    FAILURE = "FAILURE"
-----
-labelbox/schema/project_resource_tag.py
-from labelbox.orm.db_object import DbObject, Updateable
-from labelbox.orm.model import Field, Relationship
-
-
-class ProjectResourceTag(DbObject, Updateable):
-    """ Associates a ResourceTag with a Project.
- - Attributes: - resourceTagId (str) - projectId (str) - - resource_tag (Relationship): `ToOne` relationship to ResourceTag - """ - - resource_tag_id = Field.ID("resource_tag_id") - project_id = Field.ID("project_id") - ----- -labelbox/schema/task_queue.py -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field - - -class TaskQueue(DbObject): - """ - a task queue - - Attributes - name - description - queue_type - data_row_count - - Relationships - project - organization - pass_queue - fail_queue - """ - - name = Field.String("name") - description = Field.String("description") - queue_type = Field.String("queue_type") - data_row_count = Field.Int("data_row_count") - - def __init__(self, client, *args, **kwargs): - super().__init__(client, *args, **kwargs) - ----- -labelbox/schema/benchmark.py -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field, Relationship - - -class Benchmark(DbObject): - """ Represents a benchmark label. - - The Benchmarks tool works by interspersing data to be labeled, for - which there is a benchmark label, to each person labeling. These - labeled data are compared against their respective benchmark and an - accuracy score between 0 and 100 percent is calculated. - - Attributes: - created_at (datetime) - last_activity (datetime) - average_agreement (float) - completed_count (int) - - created_by (Relationship): `ToOne` relationship to User - reference_label (Relationship): `ToOne` relationship to Label - """ - created_at = Field.DateTime("created_at") - created_by = Relationship.ToOne("User", False, "created_by") - last_activity = Field.DateTime("last_activity") - average_agreement = Field.Float("average_agreement") - completed_count = Field.Int("completed_count") - - reference_label = Relationship.ToOne("Label", False, "reference_label") - - def delete(self) -> None: - label_param = "labelId" - query_str = """mutation DeleteBenchmarkPyApi($%s: ID!) { - deleteBenchmark(where: {labelId: $%s}) {id}} """ % (label_param, - label_param) - self.client.execute(query_str, - {label_param: self.reference_label().uid}) - ----- -labelbox/schema/slice.py -from dataclasses import dataclass -from typing import Optional -import warnings -from labelbox.orm.db_object import DbObject, experimental -from labelbox.orm.model import Field -from labelbox.pagination import PaginatedCollection -from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params -from labelbox.schema.export_task import ExportTask -from labelbox.schema.identifiable import GlobalKey, UniqueId -from labelbox.schema.task import Task - - -class Slice(DbObject): - """ - A Slice is a saved set of filters (saved query). - This is an abstract class and should not be instantiated. 
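-    Work with the concrete `CatalogSlice` and `ModelSlice` subclasses
-    returned by the client instead.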
-
-    Attributes:
-        name (str)
-        description (str)
-        created_at (datetime)
-        updated_at (datetime)
-        filter (json)
-    """
-
-    name = Field.String("name")
-    description = Field.String("description")
-    created_at = Field.DateTime("created_at")
-    updated_at = Field.DateTime("updated_at")
-    filter = Field.Json("filter")
-
-    @dataclass
-    class DataRowIdAndGlobalKey:
-        id: UniqueId
-        global_key: Optional[GlobalKey]
-
-        def __init__(self, id: str, global_key: Optional[str]):
-            self.id = UniqueId(id)
-            self.global_key = GlobalKey(global_key) if global_key else None
-
-        def to_hash(self):
-            return {
-                "id": self.id.key,
-                "global_key": self.global_key.key if self.global_key else None
-            }
-
-
-class CatalogSlice(Slice):
-    """
-    Represents a Slice used for filtering data rows in Catalog.
-    """
-
-    def get_data_row_ids(self) -> PaginatedCollection:
-        """
-        Fetches all data row ids that match this Slice
-
-        Returns:
-            A PaginatedCollection of data row ids
-        """
-
-        warnings.warn(
-            "get_data_row_ids will be deprecated. Use get_data_row_identifiers instead"
-        )
-
-        query_str = """
-            query getDataRowIdsBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) {
-                getDataRowIdsBySavedQuery(input: {
-                    savedQueryId: $id,
-                    after: $from
-                    first: $first
-                }) {
-                    totalCount
-                    nodes
-                    pageInfo {
-                        endCursor
-                        hasNextPage
-                    }
-                }
-            }
-        """
-        return PaginatedCollection(
-            client=self.client,
-            query=query_str,
-            params={'id': str(self.uid)},
-            dereferencing=['getDataRowIdsBySavedQuery', 'nodes'],
-            obj_class=lambda _, data_row_id: data_row_id,
-            cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])
-
-    def get_data_row_identifiers(self) -> PaginatedCollection:
-        """
-        Fetches all data row ids and global keys (where defined) that match this Slice
-
-        Returns:
-            A PaginatedCollection of Slice.DataRowIdAndGlobalKey
-        """
-        query_str = """
-            query getDataRowIdenfifiersBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) {
-                getDataRowIdentifiersBySavedQuery(input: {
-                    savedQueryId: $id,
-                    after: $from
-                    first: $first
-                }) {
-                    totalCount
-                    nodes
-                    {
-                        id
-                        globalKey
-                    }
-                    pageInfo {
-                        endCursor
-                        hasNextPage
-                    }
-                }
-            }
-        """
-        return PaginatedCollection(
-            client=self.client,
-            query=query_str,
-            params={'id': str(self.uid)},
-            dereferencing=['getDataRowIdentifiersBySavedQuery', 'nodes'],
-            obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey(
-                data_row_id_and_gk.get('id'),
-                data_row_id_and_gk.get('globalKey', None)),
-            cursor_path=[
-                'getDataRowIdentifiersBySavedQuery', 'pageInfo', 'endCursor'
-            ])
-
-    @experimental
-    def export(self,
-               task_name: Optional[str] = None,
-               params: Optional[CatalogExportParams] = None) -> ExportTask:
-        """
-        Creates a slice export task with the given params and returns the task.
-        >>> slice = client.get_catalog_slice("SLICE_ID")
-        >>> task = slice.export(
-        >>>     params={"performance_details": False, "label_details": True}
-        >>> )
-        >>> task.wait_till_done()
-        >>> task.result
-        """
-        task = self._export(task_name, params, streamable=True)
-        return ExportTask(task)
-
-    def export_v2(
-        self,
-        task_name: Optional[str] = None,
-        params: Optional[CatalogExportParams] = None,
-    ) -> Task:
-        """
-        Creates a slice export task with the given params and returns the task.
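-        Unlike `export`, this does not request a streamable task; both
-        methods accept the same `params` dictionary.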
- >>> slice = client.get_catalog_slice("SLICE_ID") - >>> task = slice.export_v2( - >>> params={"performance_details": False, "label_details": True} - >>> ) - >>> task.wait_till_done() - >>> task.result - """ - return self._export(task_name, params) - - def _export( - self, - task_name: Optional[str] = None, - params: Optional[CatalogExportParams] = None, - streamable: bool = False, - ) -> Task: - _params = params or CatalogExportParams({ - "attachments": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "model_run_ids": None, - "project_ids": None, - "interpolated_frames": False, - "all_projects": False, - "all_model_runs": False, - }) - validate_catalog_export_params(_params) - - mutation_name = "exportDataRowsInSlice" - create_task_query_str = ( - f"mutation {mutation_name}PyApi" - f"($input: ExportDataRowsInSliceInput!)" - f"{{{mutation_name}(input: $input){{taskId}}}}") - - media_type_override = _params.get('media_type_override', None) - query_params = { - "input": { - "taskName": task_name, - "filters": { - "sliceId": self.uid - }, - "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), - "projectIds": - _params.get('project_ids', None), - "modelRunIds": - _params.get('model_run_ids', None), - "allProjects": - _params.get('all_projects', False), - "allModelRuns": - _params.get('all_model_runs', False), - }, - "streamable": streamable, - } - } - - res = self.client.execute(create_task_query_str, - query_params, - error_log_key="errors") - res = res[mutation_name] - task_id = res["taskId"] - return Task.get_task(self.client, task_id) - - -class ModelSlice(Slice): - """ - Represents a Slice used for filtering data rows in Model. - """ - - @classmethod - def query_str(cls): - query_str = """ - query getDataRowIdenfifiersBySavedModelQueryPyApi($id: ID!, $modelRunId: ID, $from: DataRowIdentifierCursorInput, $first: Int!) 
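-            # Model-slice pagination uses a composite cursor
-            # (dataRowId + globalKey); see pageInfo.endCursor below.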
{ - getDataRowIdentifiersBySavedModelQuery(input: { - savedQueryId: $id, - modelRunId: $modelRunId, - after: $from - first: $first - }) { - totalCount - nodes - { - id - globalKey - } - pageInfo { - endCursor { - dataRowId - globalKey - } - hasNextPage - } - } - } - """ - return query_str - - def get_data_row_ids(self, model_run_id: str) -> PaginatedCollection: - """ - Fetches all data row ids that match this Slice - - Params - model_run_id: str, required, uid or cuid of model run - - Returns: - A PaginatedCollection of data row ids - """ - return PaginatedCollection( - client=self.client, - query=ModelSlice.query_str(), - params={ - 'id': str(self.uid), - 'modelRunId': model_run_id - }, - dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'], - obj_class=lambda _, data_row_id_and_gk: data_row_id_and_gk.get('id' - ), - cursor_path=[ - 'getDataRowIdentifiersBySavedModelQuery', 'pageInfo', - 'endCursor' - ]) - - def get_data_row_identifiers(self, - model_run_id: str) -> PaginatedCollection: - """ - Fetches all data row ids and global keys (where defined) that match this Slice - - Params: - model_run_id : str, required, uid or cuid of model run - - Returns: - A PaginatedCollection of Slice.DataRowIdAndGlobalKey - """ - return PaginatedCollection( - client=self.client, - query=ModelSlice.query_str(), - params={ - 'id': str(self.uid), - 'modelRunId': model_run_id - }, - dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'], - obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey( - data_row_id_and_gk.get('id'), - data_row_id_and_gk.get('globalKey', None)), - cursor_path=[ - 'getDataRowIdentifiersBySavedModelQuery', 'pageInfo', - 'endCursor' - ]) - ----- -labelbox/schema/webhook.py -import logging -from enum import Enum -from typing import Iterable, List - -from labelbox.orm import query -from labelbox.orm.db_object import DbObject, Updateable -from labelbox.orm.model import Entity, Field, Relationship - -logger = logging.getLogger(__name__) - - -class Webhook(DbObject, Updateable): - """ Represents a server-side rule for sending notifications to a web-server - whenever one of several predefined actions happens within a context of - a Project or an Organization. - - Attributes: - updated_at (datetime) - created_at (datetime) - url (str) - topics (str): LABEL_CREATED, LABEL_UPDATED, LABEL_DELETED - status (str): ACTIVE, INACTIVE, REVOKED - - """ - - class Status(Enum): - ACTIVE = "ACTIVE" - INACTIVE = "INACTIVE" - REVOKED = "REVOKED" - - class Topic(Enum): - LABEL_CREATED = "LABEL_CREATED" - LABEL_UPDATED = "LABEL_UPDATED" - LABEL_DELETED = "LABEL_DELETED" - WORKFLOW_ACTION = "WORKFLOW_ACTION" - - # For backwards compatibility - for topic in Status: - vars()[topic.name] = topic.value - - for status in Topic: - vars()[status.name] = status.value - - updated_at = Field.DateTime("updated_at") - created_at = Field.DateTime("created_at") - url = Field.String("url") - topics = Field.String("topics") - status = Field.String("status") - - created_by = Relationship.ToOne("User", False, "created_by") - organization = Relationship.ToOne("Organization") - project = Relationship.ToOne("Project") - - @staticmethod - def create(client, topics, url, secret, project) -> "Webhook": - """ Creates a Webhook. - - Args: - client (Client): The Labelbox client used to connect - to the server. - topics (list of str): A list of topics this Webhook should - get notifications for. 
Must be one of Webhook.Topic - url (str): The URL to which notifications should be sent - by the Labelbox server. - secret (str): A secret key used for signing notifications. - project (Project or None): The project for which notifications - should be sent. If None notifications are sent for all - events in your organization. - Returns: - A newly created Webhook. - - Raises: - ValueError: If the topic is not one of Topic or status is not one of Status - - Information on configuring your server can be found here (this is where the url points to and the secret is set). - https://docs.labelbox.com/reference/webhook - - """ - if not secret: - raise ValueError("Secret must be a non-empty string.") - if not topics: - raise ValueError("Topics must be a non-empty list.") - if not url: - raise ValueError("URL must be a non-empty string.") - Webhook.validate_topics(topics) - - project_str = "" if project is None \ - else ("project:{id:\"%s\"}," % project.uid) - - query_str = """mutation CreateWebhookPyApi { - createWebhook(data:{%s topics:{set:[%s]}, url:"%s", secret:"%s" }){%s} - } """ % (project_str, " ".join(topics), url, secret, - query.results_query_part(Entity.Webhook)) - - return Webhook(client, client.execute(query_str)["createWebhook"]) - - @staticmethod - def validate_topics(topics) -> None: - if isinstance(topics, str) or not isinstance(topics, Iterable): - raise TypeError( - f"Topics must be List[Webhook.Topic]. Found `{topics}`") - - for topic in topics: - Webhook.validate_value(topic, Webhook.Topic) - - @staticmethod - def validate_value(value, enum) -> None: - supported_values = {item.value for item in enum} - if value not in supported_values: - raise ValueError( - f"Value `{value}` does not exist in supported values. Expected one of {supported_values}" - ) - - def delete(self) -> None: - """ - Deletes the webhook - """ - self.update(status=self.Status.INACTIVE.value) - - def update(self, topics=None, url=None, status=None): - """ Updates the Webhook. - - Args: - topics (Optional[List[Topic]]): The new topics. - url Optional[str): The new URL value. - status (Optional[Status]): The new status. - If an argument is set to None then no updates will be made to that field. - - """ - - # Webhook has a custom `update` function due to custom types - # in `status` and `topics` fields. 
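-        # For example (illustrative values), update(status="INACTIVE") builds:
-        #   mutation UpdateWebhookPyApi {
-        #     updateWebhook(where: {id: "<uid>"} data:{status: INACTIVE}){...}}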
- - if topics is not None: - self.validate_topics(topics) - - if status is not None: - self.validate_value(status, self.Status) - - topics_str = "" if topics is None \ - else "topics: {set: [%s]}" % " ".join(topics) - url_str = "" if url is None else "url: \"%s\"" % url - status_str = "" if status is None else "status: %s" % status - - query_str = """mutation UpdateWebhookPyApi { - updateWebhook(where: {id: "%s"} data:{%s}){%s}} """ % ( - self.uid, ", ".join(filter(None, - (topics_str, url_str, status_str))), - query.results_query_part(Entity.Webhook)) - - self._set_field_values(self.client.execute(query_str)["updateWebhook"]) - ----- -labelbox/schema/data_row_metadata.py -# type: ignore -from datetime import datetime -from copy import deepcopy -from enum import Enum -from itertools import chain -import warnings - -from typing import List, Optional, Dict, Union, Callable, Type, Any, Generator, overload - -from labelbox import pydantic_compat -from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds -from labelbox.schema.identifiable import UniqueId, GlobalKey - -from labelbox.schema.ontology import SchemaId -from labelbox.utils import _CamelCaseMixin, camel_case, format_iso_datetime, format_iso_from_string - - -class DataRowMetadataKind(Enum): - number = "CustomMetadataNumber" - datetime = "CustomMetadataDateTime" - enum = "CustomMetadataEnum" - string = "CustomMetadataString" - option = "CustomMetadataEnumOption" - embedding = "CustomMetadataEmbedding" - - -# Metadata schema -class DataRowMetadataSchema(pydantic_compat.BaseModel): - uid: SchemaId - name: pydantic_compat.constr(strip_whitespace=True, - min_length=1, - max_length=100) - reserved: bool - kind: DataRowMetadataKind - options: Optional[List["DataRowMetadataSchema"]] - parent: Optional[SchemaId] - - -DataRowMetadataSchema.update_forward_refs() - -Embedding: Type[List[float]] = pydantic_compat.conlist(float, - min_items=128, - max_items=128) -String: Type[str] = pydantic_compat.constr(max_length=4096) - - -# Metadata base class -class DataRowMetadataField(_CamelCaseMixin): - # One of `schema_id` or `name` must be provided. 
If `schema_id` is not provided, it is
-    # inferred from `name`
-    schema_id: Optional[SchemaId] = None
-    name: Optional[str] = None
-    # value is of type `Any` so that we do not improperly coerce the value to the wrong type
-    # Additional validation is performed before upload using the schema information
-    value: Any
-
-
-class DataRowMetadata(_CamelCaseMixin):
-    global_key: Optional[str] = None
-    data_row_id: Optional[str] = None
-    fields: List[DataRowMetadataField]
-
-
-class DeleteDataRowMetadata(_CamelCaseMixin):
-    data_row_id: Union[str, UniqueId, GlobalKey]
-    fields: List[SchemaId]
-
-    class Config:
-        arbitrary_types_allowed = True
-
-
-class DataRowMetadataBatchResponse(_CamelCaseMixin):
-    global_key: Optional[str] = None
-    data_row_id: Optional[str] = None
-    error: Optional[str] = None
-    fields: List[Union[DataRowMetadataField, SchemaId]]
-
-
-# --- Batch GraphQL Objects ---
-# Don't want to crowd the namespace with internals
-
-
-# Bulk upsert values
-class _UpsertDataRowMetadataInput(_CamelCaseMixin):
-    schema_id: str
-    value: Any
-
-
-# Batch of upsert values for a data row
-class _UpsertBatchDataRowMetadata(_CamelCaseMixin):
-    global_key: Optional[str] = None
-    data_row_id: Optional[str] = None
-    fields: List[_UpsertDataRowMetadataInput]
-
-
-class _DeleteBatchDataRowMetadata(_CamelCaseMixin):
-    data_row_identifier: Union[UniqueId, GlobalKey]
-    schema_ids: List[SchemaId]
-
-    class Config:
-        arbitrary_types_allowed = True
-        alias_generator = camel_case
-
-    def dict(self, *args, **kwargs):
-        res = super().dict(*args, **kwargs)
-        if 'data_row_identifier' in res.keys():
-            key = 'data_row_identifier'
-            id_type_key = 'id_type'
-        else:
-            key = 'dataRowIdentifier'
-            id_type_key = 'idType'
-        data_row_identifier = res.pop(key)
-        res[key] = {
-            "id": data_row_identifier.key,
-            id_type_key: data_row_identifier.id_type
-        }
-        return res
-
-
-_BatchInputs = Union[List[_UpsertBatchDataRowMetadata],
-                     List[_DeleteBatchDataRowMetadata]]
-_BatchFunction = Callable[[_BatchInputs], List[DataRowMetadataBatchResponse]]
-
-
-class _UpsertCustomMetadataSchemaEnumOptionInput(_CamelCaseMixin):
-    id: Optional[SchemaId]
-    name: pydantic_compat.constr(strip_whitespace=True,
-                                 min_length=1,
-                                 max_length=100)
-    kind: str
-
-
-class _UpsertCustomMetadataSchemaInput(_CamelCaseMixin):
-    id: Optional[SchemaId]
-    name: pydantic_compat.constr(strip_whitespace=True,
-                                 min_length=1,
-                                 max_length=100)
-    kind: str
-    options: Optional[List[_UpsertCustomMetadataSchemaEnumOptionInput]]
-
-
-class DataRowMetadataOntology:
-    """ Ontology for data row metadata
-
-    Metadata provides additional context for data rows. Metadata is broken
-    into two classes - reserved and custom. Reserved fields are defined by
-    Labelbox and used for creating specific experiences in the platform.
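-
-    For example, fetching the ontology and its prebuilt name/id indices: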
- - >>> mdo = client.get_data_row_metadata_ontology() - - """ - - def __init__(self, client): - - self._client = client - self._batch_size = 50 # used for uploads and deletes - - self._raw_ontology = self._get_ontology() - self._build_ontology() - - def _build_ontology(self): - # all fields - self.fields = self._parse_ontology(self._raw_ontology) - self.fields_by_id = self._make_id_index(self.fields) - - # reserved fields - self.reserved_fields: List[DataRowMetadataSchema] = [ - f for f in self.fields if f.reserved - ] - self.reserved_by_id = self._make_id_index(self.reserved_fields) - self.reserved_by_name: Dict[str, Union[DataRowMetadataSchema, Dict[ - str, DataRowMetadataSchema]]] = self._make_name_index( - self.reserved_fields) - self.reserved_by_name_normalized: Dict[ - str, DataRowMetadataSchema] = self._make_normalized_name_index( - self.reserved_fields) - - # custom fields - self.custom_fields: List[DataRowMetadataSchema] = [ - f for f in self.fields if not f.reserved - ] - self.custom_by_id = self._make_id_index(self.custom_fields) - self.custom_by_name: Dict[str, Union[DataRowMetadataSchema, Dict[ - str, - DataRowMetadataSchema]]] = self._make_name_index(self.custom_fields) - self.custom_by_name_normalized: Dict[ - str, DataRowMetadataSchema] = self._make_normalized_name_index( - self.custom_fields) - - @staticmethod - def _lookup_in_index_by_name(reserved_index, custom_index, name): - # search through reserved names first - if name in reserved_index: - return reserved_index[name] - elif name in custom_index: - return custom_index[name] - else: - raise KeyError(f"There is no metadata with name '{name}'") - - def get_by_name( - self, name: str - ) -> Union[DataRowMetadataSchema, Dict[str, DataRowMetadataSchema]]: - """ Get metadata by name - - >>> mdo.get_by_name(name) - - Args: - name (str): Name of metadata schema - - Returns: - Metadata schema as `DataRowMetadataSchema` or dict, in case of Enum metadata - - Raises: - KeyError: When provided name is not presented in neither reserved nor custom metadata list - """ - return self._lookup_in_index_by_name(self.reserved_by_name, - self.custom_by_name, name) - - def _get_by_name_normalized(self, name: str) -> DataRowMetadataSchema: - """ Get metadata by name. 
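-        Unlike `get_by_name`, this resolves the name against the normalized
-        indices built in `_build_ontology`.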
For options, it provides the option schema instead of list of - options - """ - # using `normalized` indices to find options by name as well - return self._lookup_in_index_by_name(self.reserved_by_name_normalized, - self.custom_by_name_normalized, - name) - - @staticmethod - def _make_name_index( - fields: List[DataRowMetadataSchema] - ) -> Dict[str, Union[DataRowMetadataSchema, Dict[str, - DataRowMetadataSchema]]]: - index = {} - for f in fields: - if f.options: - index[f.name] = {} - for o in f.options: - index[f.name][o.name] = o - else: - index[f.name] = f - return index - - @staticmethod - def _make_normalized_name_index( - fields: List[DataRowMetadataSchema] - ) -> Dict[str, DataRowMetadataSchema]: - index = {} - for f in fields: - index[f.name] = f - return index - - @staticmethod - def _make_id_index( - fields: List[DataRowMetadataSchema] - ) -> Dict[SchemaId, DataRowMetadataSchema]: - index = {} - for f in fields: - index[f.uid] = f - if f.options: - for o in f.options: - index[o.uid] = o - return index - - def _get_ontology(self) -> List[Dict[str, Any]]: - query = """query GetMetadataOntologyBetaPyApi { - customMetadataOntology { - id - name - kind - reserved - options { - id - kind - name - reserved - } - }} - """ - return self._client.execute(query)["customMetadataOntology"] - - @staticmethod - def _parse_ontology(raw_ontology) -> List[DataRowMetadataSchema]: - fields = [] - copy = deepcopy(raw_ontology) - for schema in copy: - schema["uid"] = schema["id"] - options = None - if schema.get("options"): - options = [] - for option in schema["options"]: - option["uid"] = option["id"] - options.append( - DataRowMetadataSchema(**{ - **option, - **{ - "parent": schema["uid"] - } - })) - schema["options"] = options - fields.append(DataRowMetadataSchema(**schema)) - - return fields - - def refresh_ontology(self): - """ Update the `DataRowMetadataOntology` instance with the latest - metadata ontology schemas - """ - self._raw_ontology = self._get_ontology() - self._build_ontology() - - def create_schema(self, - name: str, - kind: DataRowMetadataKind, - options: List[str] = None) -> DataRowMetadataSchema: - """ Create metadata schema - - >>> mdo.create_schema(name, kind, options) - - Args: - name (str): Name of metadata schema - kind (DataRowMetadataKind): Kind of metadata schema as `DataRowMetadataKind` - options (List[str]): List of Enum options - - Returns: - Created metadata schema as `DataRowMetadataSchema` - - Raises: - KeyError: When provided name is not a valid custom metadata - """ - if not isinstance(kind, DataRowMetadataKind): - raise ValueError(f"kind '{kind}' must be a `DataRowMetadataKind`") - - upsert_schema = _UpsertCustomMetadataSchemaInput(name=name, - kind=kind.value) - if options: - if kind != DataRowMetadataKind.enum: - raise ValueError( - f"Kind '{kind}' must be an Enum, if Enum options are provided" - ) - upsert_enum_options = [ - _UpsertCustomMetadataSchemaEnumOptionInput( - name=o, kind=DataRowMetadataKind.option.value) - for o in options - ] - upsert_schema.options = upsert_enum_options - - return self._upsert_schema(upsert_schema) - - def update_schema(self, name: str, new_name: str) -> DataRowMetadataSchema: - """ Update metadata schema - - >>> mdo.update_schema(name, new_name) - - Args: - name (str): Current name of metadata schema - new_name (str): New name of metadata schema - - Returns: - Updated metadata schema as `DataRowMetadataSchema` - - Raises: - KeyError: When provided name is not a valid custom metadata - """ - schema = 
self._validate_custom_schema_by_name(name) - upsert_schema = _UpsertCustomMetadataSchemaInput(id=schema.uid, - name=new_name, - kind=schema.kind.value) - if schema.options: - upsert_enum_options = [ - _UpsertCustomMetadataSchemaEnumOptionInput( - id=o.uid, - name=o.name, - kind=DataRowMetadataKind.option.value) - for o in schema.options - ] - upsert_schema.options = upsert_enum_options - - return self._upsert_schema(upsert_schema) - - def update_enum_option(self, name: str, option: str, - new_option: str) -> DataRowMetadataSchema: - """ Update Enum metadata schema option - - >>> mdo.update_enum_option(name, option, new_option) - - Args: - name (str): Name of metadata schema to update - option (str): Name of Enum option to update - new_option (str): New name of Enum option - - Returns: - Updated metadata schema as `DataRowMetadataSchema` - - Raises: - KeyError: When provided name is not a valid custom metadata - """ - schema = self._validate_custom_schema_by_name(name) - if schema.kind != DataRowMetadataKind.enum: - raise ValueError( - f"Updating Enum option is only supported for Enum metadata schema" - ) - valid_options: List[str] = [o.name for o in schema.options] - - if option not in valid_options: - raise ValueError( - f"Enum option '{option}' is not a valid option for Enum '{name}', valid options are: {valid_options}" - ) - upsert_schema = _UpsertCustomMetadataSchemaInput(id=schema.uid, - name=schema.name, - kind=schema.kind.value) - upsert_enum_options = [] - for o in schema.options: - enum_option = _UpsertCustomMetadataSchemaEnumOptionInput( - id=o.uid, name=o.name, kind=o.kind.value) - if enum_option.name == option: - enum_option.name = new_option - upsert_enum_options.append(enum_option) - upsert_schema.options = upsert_enum_options - - return self._upsert_schema(upsert_schema) - - def delete_schema(self, name: str) -> bool: - """ Delete metadata schema - - >>> mdo.delete_schema(name) - - Args: - name: Name of metadata schema to delete - - Returns: - True if deletion is successful, False if unsuccessful - - Raises: - KeyError: When provided name is not a valid custom metadata - """ - schema = self._validate_custom_schema_by_name(name) - query = """mutation DeleteCustomMetadataSchemaPyApi($where: WhereUniqueIdInput!) 
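-        # Returns only a success flag; the cached ontology is refreshed
-        # client-side after the call (see the refresh_ontology() call below).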
{ - deleteCustomMetadataSchema(schema: $where){ - success - } - }""" - res = self._client.execute(query, {'where': { - 'id': schema.uid - }})['deleteCustomMetadataSchema'] - self.refresh_ontology() - - return res['success'] - - def parse_metadata( - self, unparsed: List[Dict[str, - List[Union[str, - Dict]]]]) -> List[DataRowMetadata]: - """ Parse metadata responses - - >>> mdo.parse_metadata([metadata]) - - Args: - unparsed: An unparsed metadata export - - Returns: - metadata: List of `DataRowMetadata` - """ - parsed = [] - if isinstance(unparsed, dict): - raise ValueError("Pass a list of dictionaries") - - for dr in unparsed: - fields = [] - if "fields" in dr: - fields = self.parse_metadata_fields(dr["fields"]) - parsed.append( - DataRowMetadata(data_row_id=dr["dataRowId"], - global_key=dr["globalKey"], - fields=fields)) - return parsed - - def parse_metadata_fields( - self, unparsed: List[Dict[str, - Dict]]) -> List[DataRowMetadataField]: - """ Parse metadata fields as list of `DataRowMetadataField` - - >>> mdo.parse_metadata_fields([metadata_fields]) - - Args: - unparsed: An unparsed list of metadata represented as a dict containing 'schemaId' and 'value' - - Returns: - metadata: List of `DataRowMetadataField` - """ - parsed = [] - if isinstance(unparsed, dict): - raise ValueError("Pass a list of dictionaries") - - for f in unparsed: - if f["schemaId"] not in self.fields_by_id: - # Fetch latest metadata ontology if metadata can't be found - self.refresh_ontology() - if f["schemaId"] not in self.fields_by_id: - raise ValueError( - f"Schema Id `{f['schemaId']}` not found in ontology") - - schema = self.fields_by_id[f["schemaId"]] - if schema.kind == DataRowMetadataKind.enum: - continue - elif schema.kind == DataRowMetadataKind.option: - field = DataRowMetadataField(schema_id=schema.parent, - value=schema.uid) - elif schema.kind == DataRowMetadataKind.datetime: - field = DataRowMetadataField(schema_id=schema.uid, - value=format_iso_from_string( - f["value"])) - else: - field = DataRowMetadataField(schema_id=schema.uid, - value=f["value"]) - - field.name = schema.name - parsed.append(field) - return parsed - - def bulk_upsert( - self, metadata: List[DataRowMetadata] - ) -> List[DataRowMetadataBatchResponse]: - """Upsert metadata to a list of data rows - - You may specify data row by either data_row_id or global_key - - >>> metadata = DataRowMetadata( - >>> data_row_id="datarow-id", # Alternatively, set global_key="global-key" - >>> fields=[ - >>> DataRowMetadataField(schema_id="schema-id", value="my-message"), - >>> ... - >>> ] - >>> ) - >>> mdo.batch_upsert([metadata]) - - Args: - metadata: List of DataRow Metadata to upsert - - Returns: - list of unsuccessful upserts. - An empty list means the upload was successful. - """ - - if not len(metadata): - raise ValueError("Empty list passed") - - def _batch_upsert( - upserts: List[_UpsertBatchDataRowMetadata] - ) -> List[DataRowMetadataBatchResponse]: - query = """mutation UpsertDataRowMetadataBetaPyApi($metadata: [DataRowCustomMetadataBatchUpsertInput!]!) 
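-            # Upserts are sent in batches of 50 items
-            # (see DataRowMetadataOntology._batch_size).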
{
-                upsertDataRowCustomMetadata(data: $metadata){
-                    globalKey
-                    dataRowId
-                    error
-                    fields {
-                        value
-                        schemaId
-                    }
-                }
-            }"""
-            res = self._client.execute(
-                query, {"metadata": upserts})['upsertDataRowCustomMetadata']
-            return [
-                DataRowMetadataBatchResponse(global_key=r['globalKey'],
-                                             data_row_id=r['dataRowId'],
-                                             error=r['error'],
-                                             fields=self.parse_metadata(
-                                                 [r])[0].fields) for r in res
-            ]
-
-        items = []
-        for m in metadata:
-            items.append(
-                _UpsertBatchDataRowMetadata(
-                    global_key=m.global_key,
-                    data_row_id=m.data_row_id,
-                    fields=list(
-                        chain.from_iterable(
-                            self._parse_upsert(f, m.data_row_id)
-                            for f in m.fields))).dict(by_alias=True))
-        res = _batch_operations(_batch_upsert, items, self._batch_size)
-        return res
-
-    def bulk_delete(
-            self, deletes: List[DeleteDataRowMetadata]
-    ) -> List[DataRowMetadataBatchResponse]:
-        """ Delete metadata from a data row by specifying the fields you want to remove
-
-        >>> delete = DeleteDataRowMetadata(
-        >>>     data_row_id=UniqueId("datarow-id"),
-        >>>     fields=[
-        >>>         "schema-id-1",
-        >>>         "schema-id-2"
-        >>>         ...
-        >>>     ]
-        >>> )
-        >>> mdo.bulk_delete([delete])
-
-        >>> delete = DeleteDataRowMetadata(
-        >>>     data_row_id=GlobalKey("global-key"),
-        >>>     fields=[
-        >>>         "schema-id-1",
-        >>>         "schema-id-2"
-        >>>         ...
-        >>>     ]
-        >>> )
-        >>> mdo.bulk_delete([delete])
-
-        >>> delete = DeleteDataRowMetadata(
-        >>>     data_row_id="datarow-id",  # a bare str is treated as a UniqueId
-        >>>     fields=[
-        >>>         "schema-id-1",
-        >>>         "schema-id-2"
-        >>>         ...
-        >>>     ]
-        >>> )
-        >>> mdo.bulk_delete([delete])
-
-
-        Args:
-            deletes: Data row and schema ids to delete
-                For data row, we support UniqueId, str, and GlobalKey.
-                If you pass a str, we will assume it is a UniqueId.
-                Do not pass a mix of data row ids and global keys in the same list.
-
-        Returns:
-            list of unsuccessful deletions.
-            An empty list means all data rows were successfully deleted.
-
-        """
-
-        if not len(deletes):
-            raise ValueError("The 'deletes' list cannot be empty.")
-
-        passed_strings = False
-        for i, delete in enumerate(deletes):
-            if isinstance(delete.data_row_id, str):
-                passed_strings = True
-                deletes[i] = DeleteDataRowMetadata(data_row_id=UniqueId(
-                    delete.data_row_id),
-                                                   fields=delete.fields)
-            elif isinstance(delete.data_row_id, UniqueId):
-                continue
-            elif isinstance(delete.data_row_id, GlobalKey):
-                continue
-            else:
-                raise ValueError(
-                    f"Invalid data row identifier type '{type(delete.data_row_id)}' for '{delete.data_row_id}'"
-                )
-
-        if passed_strings:
-            warnings.warn(
-                "Using string for data row id will be deprecated. 
Please use " - "UniqueId instead.") - - def _batch_delete( - deletes: List[_DeleteBatchDataRowMetadata] - ) -> List[DataRowMetadataBatchResponse]: - query = """mutation DeleteDataRowMetadataBetaPyApi($deletes: [DataRowIdentifierCustomMetadataBatchDeleteInput!]) { - deleteDataRowCustomMetadata(dataRowIdentifiers: $deletes) { - dataRowId - error - fields { - value - schemaId - } - } - } - """ - res = self._client.execute( - query, {"deletes": deletes})['deleteDataRowCustomMetadata'] - failures = [] - for dr in res: - dr['fields'] = [f['schemaId'] for f in dr['fields']] - failures.append(DataRowMetadataBatchResponse(**dr)) - return failures - - items = [self._validate_delete(m) for m in deletes] - return _batch_operations(_batch_delete, - items, - batch_size=self._batch_size) - - @overload - def bulk_export(self, data_row_ids: List[str]) -> List[DataRowMetadata]: - pass - - @overload - def bulk_export(self, - data_row_ids: DataRowIdentifiers) -> List[DataRowMetadata]: - pass - - def bulk_export(self, data_row_ids) -> List[DataRowMetadata]: - """ Exports metadata for a list of data rows - - >>> mdo.bulk_export([data_row.uid for data_row in data_rows]) - - Args: - data_row_ids: List of data data rows to fetch metadata for. This can be a list of strings or a DataRowIdentifiers object - DataRowIdentifier objects are lists of ids or global keys. A DataIdentifier object can be a UniqueIds or GlobalKeys class. - Returns: - A list of DataRowMetadata. - There will be one DataRowMetadata for each data_row_id passed in. - This is true even if the data row does not have any meta data. - Data rows without metadata will have empty `fields`. - - """ - if not len(data_row_ids): - raise ValueError("Empty list passed") - - if isinstance(data_row_ids, - list) and len(data_row_ids) > 0 and isinstance( - data_row_ids[0], str): - data_row_ids = UniqueIds(data_row_ids) - warnings.warn("Using data row ids will be deprecated. Please use " - "UniqueIds or GlobalKeys instead.") - - def _bulk_export( - _data_row_ids: DataRowIdentifiers) -> List[DataRowMetadata]: - query = """query dataRowCustomMetadataPyApi($dataRowIdentifiers: DataRowCustomMetadataDataRowIdentifiersInput) { - dataRowCustomMetadata(where: {dataRowIdentifiers : $dataRowIdentifiers}) { - dataRowId - globalKey - fields { - value - schemaId - } - } - } - """ - return self.parse_metadata( - self._client.execute( - query, { - "dataRowIdentifiers": { - "ids": [id for id in _data_row_ids], - "idType": _data_row_ids.id_type - } - })['dataRowCustomMetadata']) - - return _batch_operations(_bulk_export, - data_row_ids, - batch_size=self._batch_size) - - def parse_upsert_metadata(self, metadata_fields) -> List[Dict[str, Any]]: - """ Converts either `DataRowMetadataField` or a dictionary representation - of `DataRowMetadataField` into a validated, flattened dictionary of - metadata fields that are used to create data row metadata. 
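-        Names are resolved to schema ids, and Enum values are expanded into
-        their option schemas along the way.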
Used - internally in `Dataset.create_data_rows()` - - Args: - metadata_fields: List of `DataRowMetadataField` or a dictionary representation - of `DataRowMetadataField` - Returns: - List of dictionaries representing a flattened view of metadata fields - """ - - def _convert_metadata_field(metadata_field): - if isinstance(metadata_field, DataRowMetadataField): - return metadata_field - elif isinstance(metadata_field, dict): - if not "value" in metadata_field: - raise ValueError( - f"Custom metadata field '{metadata_field}' must have a 'value' key" - ) - if not "schema_id" in metadata_field and not "name" in metadata_field: - raise ValueError( - f"Custom metadata field '{metadata_field}' must have either 'schema_id' or 'name' key" - ) - return DataRowMetadataField( - schema_id=metadata_field.get("schema_id"), - name=metadata_field.get("name"), - value=metadata_field["value"]) - else: - raise ValueError( - f"Metadata field '{metadata_field}' is neither 'DataRowMetadataField' type or a dictionary" - ) - - # Convert all metadata fields to DataRowMetadataField type - metadata_fields = [_convert_metadata_field(m) for m in metadata_fields] - parsed_metadata = list( - chain.from_iterable(self._parse_upsert(m) for m in metadata_fields)) - return [m.dict(by_alias=True) for m in parsed_metadata] - - def _upsert_schema( - self, upsert_schema: _UpsertCustomMetadataSchemaInput - ) -> DataRowMetadataSchema: - query = """mutation UpsertCustomMetadataSchemaPyApi($data: UpsertCustomMetadataSchemaInput!) { - upsertCustomMetadataSchema(data: $data){ - id - name - kind - options { - id - name - kind - } - } - }""" - res = self._client.execute( - query, {"data": upsert_schema.dict(exclude_none=True) - })['upsertCustomMetadataSchema'] - self.refresh_ontology() - return _parse_metadata_schema(res) - - def _load_option_by_name(self, metadatum: DataRowMetadataField): - is_value_a_valid_schema_id = metadatum.value in self.fields_by_id - if not is_value_a_valid_schema_id: - metadatum_by_name = self.get_by_name(metadatum.name) - if metadatum.value not in metadatum_by_name: - raise KeyError( - f"There is no enum option by name '{metadatum.value}' for enum name '{metadatum.name}'" - ) - metadatum.value = metadatum_by_name[metadatum.value].uid - - def _load_schema_id_by_name(self, metadatum: DataRowMetadataField): - """ - Loads schema id by name for a metadata field including options schema id. 
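-        No-op when `name` is not set; otherwise, a missing `schema_id` (and,
-        for Enums, the option id) is resolved from the normalized name index.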
- """ - if metadatum.name is None: - return - - if metadatum.schema_id is None: - schema = self._get_by_name_normalized(metadatum.name) - metadatum.schema_id = schema.uid - if schema.options: - self._load_option_by_name(metadatum) - - def _parse_upsert( - self, - metadatum: DataRowMetadataField, - data_row_id: Optional[str] = None - ) -> List[_UpsertDataRowMetadataInput]: - """Format for metadata upserts to GQL""" - - self._load_schema_id_by_name(metadatum) - - if metadatum.schema_id not in self.fields_by_id: - # Fetch latest metadata ontology if metadata can't be found - self.refresh_ontology() - if metadatum.schema_id not in self.fields_by_id: - raise ValueError( - f"Schema Id `{metadatum.schema_id}` not found in ontology") - - schema = self.fields_by_id[metadatum.schema_id] - try: - if schema.kind == DataRowMetadataKind.datetime: - parsed = _validate_parse_datetime(metadatum) - elif schema.kind == DataRowMetadataKind.string: - parsed = _validate_parse_text(metadatum) - elif schema.kind == DataRowMetadataKind.number: - parsed = _validate_parse_number(metadatum) - elif schema.kind == DataRowMetadataKind.embedding: - parsed = _validate_parse_embedding(metadatum) - elif schema.kind == DataRowMetadataKind.enum: - parsed = _validate_enum_parse(schema, metadatum) - elif schema.kind == DataRowMetadataKind.option: - raise ValueError( - "An Option id should not be set as the Schema id") - else: - raise ValueError(f"Unknown type: {schema}") - except ValueError as e: - error_str = f"Could not validate metadata [{metadatum}]" - if data_row_id: - error_str += f", data_row_id='{data_row_id}'" - raise ValueError(f"{error_str}. Reason: {e}") - - return [_UpsertDataRowMetadataInput(**p) for p in parsed] - - def _validate_delete(self, delete: DeleteDataRowMetadata): - if not len(delete.fields): - raise ValueError(f"No fields specified for {delete.data_row_id}") - - deletes = set() - for schema_id in delete.fields: - if schema_id not in self.fields_by_id: - # Fetch latest metadata ontology if metadata can't be found - self.refresh_ontology() - if schema_id not in self.fields_by_id: - raise ValueError( - f"Schema Id `{schema_id}` not found in ontology") - - schema = self.fields_by_id[schema_id] - # handle users specifying enums by adding all option enums - if schema.kind == DataRowMetadataKind.enum: - [deletes.add(o.uid) for o in schema.options] - - deletes.add(schema.uid) - - return _DeleteBatchDataRowMetadata( - data_row_identifier=delete.data_row_id, - schema_ids=list(delete.fields)).dict(by_alias=True) - - def _validate_custom_schema_by_name(self, - name: str) -> DataRowMetadataSchema: - if name not in self.custom_by_name_normalized: - # Fetch latest metadata ontology if metadata can't be found - self.refresh_ontology() - if name not in self.custom_by_name_normalized: - raise KeyError(f"'{name}' is not a valid custom metadata") - - return self.custom_by_name_normalized[name] - - -def _batch_items(iterable: List[Any], size: int) -> Generator[Any, None, None]: - l = len(iterable) - for ndx in range(0, l, size): - yield iterable[ndx:min(ndx + size, l)] - - -def _batch_operations( - batch_function: _BatchFunction, - items: List, - batch_size: int = 100, -): - response = [] - - for batch in _batch_items(items, batch_size): - response += batch_function(batch) - return response - - -def _validate_parse_embedding( - field: DataRowMetadataField -) -> List[Dict[str, Union[SchemaId, Embedding]]]: - - if isinstance(field.value, list): - if not (Embedding.min_items <= len(field.value) <= Embedding.max_items): - 
raise ValueError( - "Embedding length invalid. " - "Must have length within the interval " - f"[{Embedding.min_items},{Embedding.max_items}]. Found {len(field.value)}" - ) - field.value = [float(x) for x in field.value] - else: - raise ValueError( - f"Expected a list for embedding. Found {type(field.value)}") - return [field.dict(by_alias=True)] - - -def _validate_parse_number( - field: DataRowMetadataField -) -> List[Dict[str, Union[SchemaId, str, float, int]]]: - field.value = float(field.value) - return [field.dict(by_alias=True)] - - -def _validate_parse_datetime( - field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]: - if isinstance(field.value, str): - field.value = format_iso_from_string(field.value) - elif not isinstance(field.value, datetime): - raise TypeError( - f"Value for datetime fields must be either a string or datetime object. Found {type(field.value)}" - ) - - return [{ - "schemaId": field.schema_id, - "value": format_iso_datetime(field.value) - }] - - -def _validate_parse_text( - field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]: - if not isinstance(field.value, str): - raise ValueError( - f"Expected a string type for the text field. Found {type(field.value)}" - ) - - if len(field.value) > String.max_length: - raise ValueError( - f"String fields cannot exceed {String.max_length} characters.") - - return [field.dict(by_alias=True)] - - -def _validate_enum_parse( - schema: DataRowMetadataSchema, - field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, dict]]]: - if schema.options: - if field.value not in {o.uid for o in schema.options}: - raise ValueError( - f"Option `{field.value}` not found for {field.schema_id}") - else: - raise ValueError("Incorrectly specified enum schema") - - return [{ - "schemaId": field.schema_id, - "value": {} - }, { - "schemaId": field.value, - "value": {} - }] - - -def _parse_metadata_schema( - unparsed: Dict[str, Union[str, List]]) -> DataRowMetadataSchema: - uid = unparsed['id'] - name = unparsed['name'] - kind = DataRowMetadataKind(unparsed['kind']) - options = [ - DataRowMetadataSchema(uid=o['id'], - name=o['name'], - reserved=False, - kind=DataRowMetadataKind.option, - parent=uid) for o in unparsed['options'] - ] - return DataRowMetadataSchema(uid=uid, - name=name, - reserved=False, - kind=kind, - options=options or None) - ----- -labelbox/schema/annotation_import.py -import functools -import json -import logging -import os -import time -from typing import Any, BinaryIO, Dict, List, Union, TYPE_CHECKING, cast -from collections import defaultdict - -from google.api_core import retry -from labelbox import parser -import requests -from tqdm import tqdm # type: ignore - -import labelbox -from labelbox.orm import query -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field, Relationship -from labelbox.utils import is_exactly_one_set -from labelbox.schema.confidence_presence_checker import LabelsConfidencePresenceChecker -from labelbox.schema.enums import AnnotationImportState -from labelbox.schema.serialization import serialize_labels - -if TYPE_CHECKING: - from labelbox.types import Label - -NDJSON_MIME_TYPE = "application/x-ndjson" -ANNOTATION_PER_LABEL_LIMIT = 5000 - -logger = logging.getLogger(__name__) - - -class AnnotationImport(DbObject): - name = Field.String("name") - state = Field.Enum(AnnotationImportState, "state") - input_file_url = Field.String("input_file_url") - error_file_url = Field.String("error_file_url") - status_file_url = 
Field.String("status_file_url")
-    progress = Field.String("progress")
-
-    created_by = Relationship.ToOne("User", False, "created_by")
-
-    @property
-    def inputs(self) -> List[Dict[str, Any]]:
-        """
-        Inputs for each individual annotation uploaded.
-        This should match the ndjson annotations that you have uploaded.
-        Returns:
-            Uploaded ndjson.
-        * This information will expire after 24 hours.
-        """
-        return self._fetch_remote_ndjson(self.input_file_url)
-
-    @property
-    def errors(self) -> List[Dict[str, Any]]:
-        """
-        Errors for each individual annotation uploaded. This is a subset of statuses.
-
-        Returns:
-            List of dicts containing error messages. An empty list means there were no errors.
-            See `AnnotationImport.statuses` for more details.
-        * This information will expire after 24 hours.
-        """
-        self.wait_until_done()
-        return self._fetch_remote_ndjson(self.error_file_url)
-
-    @property
-    def statuses(self) -> List[Dict[str, Any]]:
-        """
-        Status for each individual annotation uploaded.
-
-        Returns:
-            A status for each annotation if the upload is done running.
-            See the table below for more details.
-
-        .. list-table::
-           :widths: 15 150
-           :header-rows: 1
-
-           * - Field
-             - Description
-           * - uuid
-             - Specifies the annotation for the status row.
-           * - dataRow
-             - JSON object containing the Labelbox data row ID for the annotation.
-           * - status
-             - Indicates SUCCESS or FAILURE.
-           * - errors
-             - An array of error messages included when status is FAILURE. Each error has a name, message and optional (key might not exist) additional_info.
-
-        * This information will expire after 24 hours.
-        """
-        self.wait_until_done()
-        return self._fetch_remote_ndjson(self.status_file_url)
-
-    def wait_until_done(self,
-                        sleep_time_seconds: int = 10,
-                        show_progress: bool = False) -> None:
-        """Blocks until the import job is done.
-        Blocks until the AnnotationImport.state changes either to
-        `AnnotationImportState.FINISHED` or `AnnotationImportState.FAILED`,
-        periodically refreshing the object's state.
-        Args:
-            sleep_time_seconds (int): time in seconds to wait between subsequent API calls
-            show_progress (bool): whether to display a progress bar
-        """
-        pbar = tqdm(total=100,
-                    bar_format="{n}% |{bar}| [{elapsed}, {rate_fmt}{postfix}]"
-                   ) if show_progress else None
-        while self.state.value == AnnotationImportState.RUNNING.value:
-            logger.info(f"Sleeping for {sleep_time_seconds} seconds...")
-            time.sleep(sleep_time_seconds)
-            self.__backoff_refresh()
-            if self.progress and pbar:
-                pbar.update(int(self.progress.replace("%", "")) - pbar.n)
-
-        if pbar:
-            pbar.update(100 - pbar.n)
-            pbar.close()
-
-    @retry.Retry(predicate=retry.if_exception_type(
-        labelbox.exceptions.ApiLimitError, labelbox.exceptions.TimeoutError,
-        labelbox.exceptions.NetworkError))
-    def __backoff_refresh(self) -> None:
-        self.refresh()
-
-    @functools.lru_cache()
-    def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]:
-        """
-        Fetches the remote ndjson file and caches the results.
-        Args:
-            url (str): Can be any url pointing to an ndjson file.
-        Returns:
-            ndjson as a list of dicts.
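-            Results are cached per url via the `functools.lru_cache`
-            decorator above, so repeated property accesses do not re-download.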
- """ - if self.state == AnnotationImportState.FAILED: - raise ValueError("Import failed.") - - response = requests.get(url) - response.raise_for_status() - return parser.loads(response.text) - - @classmethod - def _create_from_bytes(cls, client, variables, query_str, file_name, - bytes_data) -> Dict[str, Any]: - operations = json.dumps({"variables": variables, "query": query_str}) - data = { - "operations": operations, - "map": (None, json.dumps({file_name: ["variables.file"]})) - } - file_data = (file_name, bytes_data, NDJSON_MIME_TYPE) - files = {file_name: file_data} - return client.execute(data=data, files=files) - - @classmethod - def _get_ndjson_from_objects(cls, objects: Union[List[Dict[str, Any]], - List["Label"]], - object_name: str) -> BinaryIO: - if not isinstance(objects, list): - raise TypeError( - f"{object_name} must be in a form of list. Found {type(objects)}" - ) - - objects = serialize_labels(objects) - cls._validate_data_rows(objects) - - data_str = parser.dumps(objects) - if not data_str: - raise ValueError(f"{object_name} cannot be empty") - - return data_str.encode( - 'utf-8' - ) # NOTICE this method returns bytes, NOT BinaryIO... should have done io.BytesIO(...) but not going to change this at the moment since it works and fools mypy - - def refresh(self) -> None: - """Synchronizes values of all fields with the database. - """ - cls = type(self) - res = cls.from_name(self.client, - self.parent_id, - self.name, - as_json=True) - self._set_field_values(res) - - @classmethod - def _validate_data_rows(cls, objects: List[Dict[str, Any]]): - """ - Validates annotations by checking 'dataRow' is provided - and only one of 'id' or 'globalKey' is provided. - - Shows up to `max_num_errors` errors if invalidated, to prevent - large number of error messages from being printed out - """ - errors = [] - max_num_errors = 100 - labels_per_datarow: Dict[str, Dict[str, int]] = defaultdict( - lambda: defaultdict(int)) - for object in objects: - if 'dataRow' not in object: - errors.append(f"'dataRow' is missing in {object}") - continue - data_row_object = object['dataRow'] - if not is_exactly_one_set(data_row_object.get('id'), - data_row_object.get('globalKey')): - errors.append( - f"Must provide only one of 'id' or 'globalKey' for 'dataRow' in {object}" - ) - else: - data_row_id = data_row_object.get( - 'globalKey') or data_row_object.get('id') - name = object.get('name') - if name: - labels_per_datarow[data_row_id][name] += 1 - for data_row_id, label_annotations in labels_per_datarow.items(): - for label_name, annotations in label_annotations.items(): - if annotations > ANNOTATION_PER_LABEL_LIMIT: - errors.append( - f"Row with id or global key {data_row_id} has {annotations} annotations for label {label_name}.\ - Imports are limited to {ANNOTATION_PER_LABEL_LIMIT} annotations per data row." - ) - if errors: - errors_length = len(errors) - formatted_errors = '\n'.join(errors[:max_num_errors]) - if errors_length > max_num_errors: - logger.warning( - f"Found more than {max_num_errors} errors. Showing first {max_num_errors} error messages..." - ) - raise ValueError( - f"Error while validating annotations. Found {errors_length} annotations with errors. 
Errors:\n{formatted_errors}" - ) - - @classmethod - def from_name(cls, - client: "labelbox.Client", - parent_id: str, - name: str, - as_json: bool = False): - raise NotImplementedError("Inheriting class must override") - - @property - def parent_id(self) -> str: - raise NotImplementedError("Inheriting class must override") - - -class MEAPredictionImport(AnnotationImport): - model_run_id = Field.String("model_run_id") - - @property - def parent_id(self) -> str: - """ - Identifier for this import. Used to refresh the status - """ - return self.model_run_id - - @classmethod - def create_from_file(cls, client: "labelbox.Client", model_run_id: str, - name: str, path: str) -> "MEAPredictionImport": - """ - Create an MEA prediction import job from a file of annotations - - Args: - client: Labelbox Client for executing queries - model_run_id: Model run to import labels into - name: Name of the import job. Can be used to reference the task later - path: Path to ndjson file containing annotations - Returns: - MEAPredictionImport - """ - if os.path.exists(path): - with open(path, 'rb') as f: - return cls._create_mea_import_from_bytes( - client, model_run_id, name, f, - os.stat(path).st_size) - else: - raise ValueError(f"File {path} is not accessible") - - @classmethod - def create_from_objects( - cls, client: "labelbox.Client", model_run_id: str, name, - predictions: Union[List[Dict[str, Any]], List["Label"]] - ) -> "MEAPredictionImport": - """ - Create an MEA prediction import job from an in memory dictionary - - Args: - client: Labelbox Client for executing queries - model_run_id: Model run to import labels into - name: Name of the import job. Can be used to reference the task later - predictions: List of prediction annotations - Returns: - MEAPredictionImport - """ - data = cls._get_ndjson_from_objects(predictions, 'annotations') - - return cls._create_mea_import_from_bytes(client, model_run_id, name, - data, len(str(data))) - - @classmethod - def create_from_url(cls, client: "labelbox.Client", model_run_id: str, - name: str, url: str) -> "MEAPredictionImport": - """ - Create an MEA prediction import job from a url - The url must point to a file containing prediction annotations. - - Args: - client: Labelbox Client for executing queries - model_run_id: Model run to import labels into - name: Name of the import job. Can be used to reference the task later - url: Url pointing to file to upload - Returns: - MEAPredictionImport - """ - if requests.head(url): - query_str = cls._get_url_mutation() - return cls( - client, - client.execute(query_str, - params={ - "fileUrl": url, - "modelRunId": model_run_id, - 'name': name - })["createModelErrorAnalysisPredictionImport"]) - else: - raise ValueError(f"Url {url} is not reachable") - - @classmethod - def from_name(cls, - client: "labelbox.Client", - model_run_id: str, - name: str, - as_json: bool = False) -> "MEAPredictionImport": - """ - Retrieves an MEA import job. - - Args: - client: Labelbox Client for executing queries - model_run_id: ID used for querying import jobs - name: Name of the import job. - Returns: - MEAPredictionImport - """ - query_str = """query getModelErrorAnalysisPredictionImportPyApi($modelRunId : ID!, $name: String!) 
{ - modelErrorAnalysisPredictionImport( - where: {modelRunId: $modelRunId, name: $name}){ - %s - }}""" % query.results_query_part(cls) - params = { - "modelRunId": model_run_id, - "name": name, - } - response = client.execute(query_str, params) - if response is None: - raise labelbox.exceptions.ResourceNotFoundError( - MEAPredictionImport, params) - response = response["modelErrorAnalysisPredictionImport"] - if as_json: - return response - return cls(client, response) - - @classmethod - def _get_url_mutation(cls) -> str: - return """mutation createMEAPredictionImportByUrlPyApi($modelRunId : ID!, $name: String!, $fileUrl: String!) { - createModelErrorAnalysisPredictionImport(data: { - modelRunId: $modelRunId - name: $name - fileUrl: $fileUrl - }) {%s} - }""" % query.results_query_part(cls) - - @classmethod - def _get_file_mutation(cls) -> str: - return """mutation createMEAPredictionImportByFilePyApi($modelRunId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) { - createModelErrorAnalysisPredictionImport(data: { - modelRunId: $modelRunId name: $name filePayload: { file: $file, contentLength: $contentLength} - }) {%s} - }""" % query.results_query_part(cls) - - @classmethod - def _create_mea_import_from_bytes( - cls, client: "labelbox.Client", model_run_id: str, name: str, - bytes_data: BinaryIO, content_len: int) -> "MEAPredictionImport": - file_name = f"{model_run_id}__{name}.ndjson" - variables = { - "file": None, - "contentLength": content_len, - "modelRunId": model_run_id, - "name": name - } - query_str = cls._get_file_mutation() - res = cls._create_from_bytes( - client, - variables, - query_str, - file_name, - bytes_data, - ) - return cls(client, res["createModelErrorAnalysisPredictionImport"]) - - -class MEAToMALPredictionImport(AnnotationImport): - project = Relationship.ToOne("Project", cache=True) - - @property - def parent_id(self) -> str: - """ - Identifier for this import. Used to refresh the status - """ - return self.project().uid - - @classmethod - def create_for_model_run_data_rows(cls, client: "labelbox.Client", - model_run_id: str, - data_row_ids: List[str], project_id: str, - name: str) -> "MEAToMALPredictionImport": - """ - Create an MEA to MAL prediction import job from a list of data row ids of a specific model run - - Args: - client: Labelbox Client for executing queries - data_row_ids: A list of data row ids - model_run_id: model run id - project_id: Project to import the predictions into - name: Name of the import job. Can be used to reference the task later - Returns: - MEAToMALPredictionImport - """ - query_str = cls._get_model_run_data_rows_mutation() - return cls( - client, - client.execute(query_str, - params={ - "dataRowIds": data_row_ids, - "modelRunId": model_run_id, - "projectId": project_id, - "name": name - })["createMalPredictionImportForModelRunDataRows"]) - - @classmethod - def from_name(cls, - client: "labelbox.Client", - project_id: str, - name: str, - as_json: bool = False) -> "MEAToMALPredictionImport": - """ - Retrieves an MEA to MAL import job. - - Args: - client: Labelbox Client for executing queries - project_id: ID used for querying import jobs - name: Name of the import job. - Returns: - MEAToMALPredictionImport - """ - query_str = """query getMEAToMALPredictionImportPyApi($projectId : ID!, $name: String!) 
{ - meaToMalPredictionImport( - where: {projectId: $projectId, name: $name}){ - %s - }}""" % query.results_query_part(cls) - params = { - "projectId": project_id, - "name": name, - } - response = client.execute(query_str, params) - if response is None: - raise labelbox.exceptions.ResourceNotFoundError( - MEAToMALPredictionImport, params) - response = response["meaToMalPredictionImport"] - if as_json: - return response - return cls(client, response) - - @classmethod - def _get_model_run_data_rows_mutation(cls) -> str: - return """mutation createMalPredictionImportForModelRunDataRowsPyApi($dataRowIds: [ID!]!, $name: String!, $modelRunId: ID!, $projectId:ID!) { - createMalPredictionImportForModelRunDataRows(data: { - name: $name - modelRunId: $modelRunId - dataRowIds: $dataRowIds - projectId: $projectId - }) {%s} - }""" % query.results_query_part(cls) - - -class MALPredictionImport(AnnotationImport): - project = Relationship.ToOne("Project", cache=True) - - @property - def parent_id(self) -> str: - """ - Identifier for this import. Used to refresh the status - """ - return self.project().uid - - @classmethod - def create_from_file(cls, client: "labelbox.Client", project_id: str, - name: str, path: str) -> "MALPredictionImport": - """ - Create an MAL prediction import job from a file of annotations - - Args: - client: Labelbox Client for executing queries - project_id: Project to import labels into - name: Name of the import job. Can be used to reference the task later - path: Path to ndjson file containing annotations - Returns: - MALPredictionImport - """ - if os.path.exists(path): - with open(path, 'rb') as f: - return cls._create_mal_import_from_bytes( - client, project_id, name, f, - os.stat(path).st_size) - else: - raise ValueError(f"File {path} is not accessible") - - @classmethod - def create_from_objects( - cls, client: "labelbox.Client", project_id: str, name: str, - predictions: Union[List[Dict[str, Any]], List["Label"]] - ) -> "MALPredictionImport": - """ - Create an MAL prediction import job from an in-memory list of annotations - - Args: - client: Labelbox Client for executing queries - project_id: Project to import labels into - name: Name of the import job. Can be used to reference the task later - predictions: List of prediction annotations - Returns: - MALPredictionImport - """ - - data = cls._get_ndjson_from_objects(predictions, 'annotations') - if len(predictions) > 0 and isinstance(predictions[0], Dict): - predictions_dicts = cast(List[Dict[str, Any]], predictions) - has_confidence = LabelsConfidencePresenceChecker.check( - predictions_dicts) - if has_confidence: - logger.warning(""" - Confidence scores are not supported in MAL Prediction Import. - Corresponding confidence score values will be ignored. - """) - return cls._create_mal_import_from_bytes(client, project_id, name, data, - len(str(data))) - - @classmethod - def create_from_url(cls, client: "labelbox.Client", project_id: str, - name: str, url: str) -> "MALPredictionImport": - """ - Create an MAL prediction import job from a url - The url must point to a file containing prediction annotations. - - Args: - client: Labelbox Client for executing queries - project_id: Project to import labels into - name: Name of the import job. 
Can be used to reference the task later - url: Url pointing to file to upload - Returns: - MALPredictionImport - """ - if requests.head(url): - query_str = cls._get_url_mutation() - return cls( - client, - client.execute( - query_str, - params={ - "fileUrl": url, - "projectId": project_id, - 'name': name - })["createModelAssistedLabelingPredictionImport"]) - else: - raise ValueError(f"Url {url} is not reachable") - - @classmethod - def from_name(cls, - client: "labelbox.Client", - project_id: str, - name: str, - as_json: bool = False) -> "MALPredictionImport": - """ - Retrieves an MAL import job. - - Args: - client: Labelbox Client for executing queries - project_id: ID used for querying import jobs - name: Name of the import job. - Returns: - MALPredictionImport - """ - query_str = """query getModelAssistedLabelingPredictionImportPyApi($projectId : ID!, $name: String!) { - modelAssistedLabelingPredictionImport( - where: {projectId: $projectId, name: $name}){ - %s - }}""" % query.results_query_part(cls) - params = { - "projectId": project_id, - "name": name, - } - response = client.execute(query_str, params) - if response is None: - raise labelbox.exceptions.ResourceNotFoundError( - MALPredictionImport, params) - response = response["modelAssistedLabelingPredictionImport"] - if as_json: - return response - return cls(client, response) - - @classmethod - def _get_url_mutation(cls) -> str: - return """mutation createMALPredictionImportByUrlPyApi($projectId : ID!, $name: String!, $fileUrl: String!) { - createModelAssistedLabelingPredictionImport(data: { - projectId: $projectId - name: $name - fileUrl: $fileUrl - }) {%s} - }""" % query.results_query_part(cls) - - @classmethod - def _get_file_mutation(cls) -> str: - return """mutation createMALPredictionImportByFilePyApi($projectId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) { - createModelAssistedLabelingPredictionImport(data: { - projectId: $projectId name: $name filePayload: { file: $file, contentLength: $contentLength} - }) {%s} - }""" % query.results_query_part(cls) - - @classmethod - def _create_mal_import_from_bytes( - cls, client: "labelbox.Client", project_id: str, name: str, - bytes_data: BinaryIO, content_len: int) -> "MALPredictionImport": - file_name = f"{project_id}__{name}.ndjson" - variables = { - "file": None, - "contentLength": content_len, - "projectId": project_id, - "name": name - } - query_str = cls._get_file_mutation() - res = cls._create_from_bytes(client, variables, query_str, file_name, - bytes_data) - return cls(client, res["createModelAssistedLabelingPredictionImport"]) - - -class LabelImport(AnnotationImport): - project = Relationship.ToOne("Project", cache=True) - - @property - def parent_id(self) -> str: - """ - Identifier for this import. Used to refresh the status - """ - return self.project().uid - - @classmethod - def create_from_file(cls, client: "labelbox.Client", project_id: str, - name: str, path: str) -> "LabelImport": - """ - Create a label import job from a file of annotations - - Args: - client: Labelbox Client for executing queries - project_id: Project to import labels into - name: Name of the import job. 
Can be used to reference the task later - path: Path to ndjson file containing annotations - Returns: - LabelImport - """ - if os.path.exists(path): - with open(path, 'rb') as f: - return cls._create_label_import_from_bytes( - client, project_id, name, f, - os.stat(path).st_size) - else: - raise ValueError(f"File {path} is not accessible") - - @classmethod - def create_from_objects( - cls, client: "labelbox.Client", project_id: str, name: str, - labels: Union[List[Dict[str, Any]], - List["Label"]]) -> "LabelImport": - """ - Create a label import job from an in-memory list of labels - - Args: - client: Labelbox Client for executing queries - project_id: Project to import labels into - name: Name of the import job. Can be used to reference the task later - labels: List of labels - Returns: - LabelImport - """ - data = cls._get_ndjson_from_objects(labels, 'labels') - - if len(labels) > 0 and isinstance(labels[0], Dict): - label_dicts = cast(List[Dict[str, Any]], labels) - has_confidence = LabelsConfidencePresenceChecker.check(label_dicts) - if has_confidence: - logger.warning(""" - Confidence scores are not supported in Label Import. - Corresponding confidence score values will be ignored. - """) - return cls._create_label_import_from_bytes(client, project_id, name, - data, len(str(data))) - - @classmethod - def create_from_url(cls, client: "labelbox.Client", project_id: str, - name: str, url: str) -> "LabelImport": - """ - Create a label annotation import job from a url - The url must point to a file containing label annotations. - - Args: - client: Labelbox Client for executing queries - project_id: Project to import labels into - name: Name of the import job. Can be used to reference the task later - url: Url pointing to file to upload - Returns: - LabelImport - """ - if requests.head(url): - query_str = cls._get_url_mutation() - return cls( - client, - client.execute(query_str, - params={ - "fileUrl": url, - "projectId": project_id, - 'name': name - })["createLabelImport"]) - else: - raise ValueError(f"Url {url} is not reachable") - - @classmethod - def from_name(cls, - client: "labelbox.Client", - project_id: str, - name: str, - as_json: bool = False) -> "LabelImport": - """ - Retrieves a label import job. - - Args: - client: Labelbox Client for executing queries - project_id: ID used for querying import jobs - name: Name of the import job. - Returns: - LabelImport - """ - query_str = """query getLabelImportPyApi($projectId : ID!, $name: String!) { - labelImport( - where: {projectId: $projectId, name: $name}){ - %s - }}""" % query.results_query_part(cls) - params = { - "projectId": project_id, - "name": name, - } - response = client.execute(query_str, params) - if response is None: - raise labelbox.exceptions.ResourceNotFoundError(LabelImport, params) - response = response["labelImport"] - if as_json: - return response - return cls(client, response) - - @classmethod - def _get_url_mutation(cls) -> str: - return """mutation createLabelImportByUrlPyApi($projectId : ID!, $name: String!, $fileUrl: String!) { - createLabelImport(data: { - projectId: $projectId - name: $name - fileUrl: $fileUrl - }) {%s} - }""" % query.results_query_part(cls) - - @classmethod - def _get_file_mutation(cls) -> str: - return """mutation createLabelImportByFilePyApi($projectId : ID!, $name: String!, $file: Upload!, $contentLength: Int!) 
{ - createLabelImport(data: { - projectId: $projectId name: $name filePayload: { file: $file, contentLength: $contentLength} - }) {%s} - }""" % query.results_query_part(cls) - - @classmethod - def _create_label_import_from_bytes(cls, client: "labelbox.Client", - project_id: str, name: str, - bytes_data: BinaryIO, - content_len: int) -> "LabelImport": - file_name = f"{project_id}__{name}.ndjson" - variables = { - "file": None, - "contentLength": content_len, - "projectId": project_id, - "name": name - } - query_str = cls._get_file_mutation() - res = cls._create_from_bytes(client, variables, query_str, file_name, - bytes_data) - return cls(client, res["createLabelImport"]) - ----- -labelbox/schema/organization.py -import json -from typing import TYPE_CHECKING, List, Optional, Dict - -from labelbox.exceptions import LabelboxError -from labelbox import utils -from labelbox.orm.db_object import DbObject, query, Entity -from labelbox.orm.model import Field, Relationship -from labelbox.schema.invite import InviteLimit -from labelbox.schema.resource_tag import ResourceTag - -if TYPE_CHECKING: - from labelbox import Role, User, ProjectRole, Invite, InviteLimit, IAMIntegration - - -class Organization(DbObject): - """ An Organization is a group of Users. - - It is associated with data created by Users within that Organization. - Typically all Users within an Organization have access to data created by any User in the same Organization. - - Attributes: - updated_at (datetime) - created_at (datetime) - name (str) - - users (Relationship): `ToMany` relationship to User - projects (Relationship): `ToMany` relationship to Project - webhooks (Relationship): `ToMany` relationship to Webhook - """ - - # RelationshipManagers in Organization use the type in Query (and - # not the source object) because the server-side does not support - # filtering on ID in the query for getting a single organization. - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - for relationship in self.relationships(): - getattr(self, relationship.name).filter_on_id = False - - updated_at = Field.DateTime("updated_at") - created_at = Field.DateTime("created_at") - name = Field.String("name") - - # Relationships - users = Relationship.ToMany("User", False) - projects = Relationship.ToMany("Project", True) - webhooks = Relationship.ToMany("Webhook", False) - resource_tags = Relationship.ToMany("ResourceTags", False) - - def invite_user( - self, - email: str, - role: "Role", - project_roles: Optional[List["ProjectRole"]] = None) -> "Invite": - """ - Invite a new member to the org. This will send the user an email invite - - Args: - email (str): email address of the user to invite - role (Role): Role to assign to the user - project_roles (Optional[List[ProjectRoles]]): List of project roles to assign to the User (if they have a project based org role). - - Returns: - Invite for the user - - Notes: - 1. Multiple invites can be sent for the same email. This can only be resolved in the UI for now. - Future releases of the SDK will support the ability to query and revoke invites to solve this problem (and/or checking on the backend) - 2. Some server-side responses are unclear (e.g. if a user invites themselves, `None` is returned, which the SDK raises as a `LabelboxError`) - """ - - if project_roles and role.name != "NONE": - raise ValueError( - f"Project roles cannot be set for a user with organization level permissions. 
Found role name `{role.name}`, expected `NONE`" - ) - - data_param = "data" - query_str = """mutation createInvitesPyApi($%s: [CreateInviteInput!]){ - createInvites(data: $%s){ invite { id createdAt organizationRoleName inviteeEmail inviter { %s } }}}""" % ( - data_param, data_param, query.results_query_part(Entity.User)) - - projects = [{ - "projectId": project_role.project.uid, - "projectRoleId": project_role.role.uid - } for project_role in project_roles or []] - - res = self.client.execute( - query_str, { - data_param: [{ - "inviterId": self.client.get_user().uid, - "inviteeEmail": email, - "organizationId": self.uid, - "organizationRoleId": role.uid, - "projects": projects - }] - }) - invite_response = res['createInvites'][0]['invite'] - if not invite_response: - raise LabelboxError(f"Unable to send invite for email {email}") - return Entity.Invite(self.client, invite_response) - - def invite_limit(self) -> InviteLimit: - """ Retrieve invite limits for the org - This already accounts for users currently in the org - Meaning that `used = users + invites, remaining = limit - (users + invites)` - - Returns: - InviteLimit - - """ - org_id_param = "organizationId" - res = self.client.execute( - """query InvitesLimitPyApi($%s: ID!) { - invitesLimit(where: {id: $%s}) { used limit remaining } - }""" % (org_id_param, org_id_param), {org_id_param: self.uid}) - return InviteLimit(**{ - utils.snake_case(k): v for k, v in res['invitesLimit'].items() - }) - - def remove_user(self, user: "User") -> None: - """ - Deletes a user from the organization. This cannot be undone without sending another invite. - - Args: - user (User): The user to delete from the org - """ - - user_id_param = "userId" - self.client.execute( - """mutation DeleteMemberPyApi($%s: ID!) { - updateUser(where: {id: $%s}, data: {deleted: true}) { id deleted } - }""" % (user_id_param, user_id_param), {user_id_param: user.uid}) - - def create_resource_tag(self, tag: Dict[str, str]) -> ResourceTag: - """ - Creates a resource tag. - >>> tag = {'text': 'tag-1', 'color': 'ffffff'} - - Args: - tag (dict): A resource tag {'text': 'tag-1', 'color': 'ffffff'} - Returns: - The created resource tag. - """ - tag_text_param = "text" - tag_color_param = "color" - - query_str = """mutation CreateResourceTagPyApi($text:String!,$color:String!) { - createResourceTag(input:{text:$%s,color:$%s}) {%s}} - """ % (tag_text_param, tag_color_param, - query.results_query_part(ResourceTag)) - - params = { - tag_text_param: tag.get("text", None), - tag_color_param: tag.get("color", None) - } - if not all(params.values()): - raise ValueError( - f"tag must contain 'text' and 'color' keys. received: {tag}") - - res = self.client.execute(query_str, params) - return ResourceTag(self.client, res['createResourceTag']) - - def get_resource_tags(self) -> List[ResourceTag]: - """ - Returns all resource tags for an organization - """ - query_str = """query GetOrganizationResourceTagsPyApi{organization{resourceTag{%s}}}""" % ( - query.results_query_part(ResourceTag)) - - return [ - ResourceTag(self.client, tag) for tag in self.client.execute( - query_str)['organization']['resourceTag'] - ] - - def get_iam_integrations(self) -> List["IAMIntegration"]: - """ - Returns all IAM Integrations for an organization - """ - res = self.client.execute( - """query getAllIntegrationsPyApi { iamIntegrations { - %s - settings { - __typename - ... on AwsIamIntegrationSettings {roleArn} - ... 
on GcpIamIntegrationSettings {serviceAccountEmailId readBucket} - } - - } } """ % query.results_query_part(Entity.IAMIntegration)) - return [ - Entity.IAMIntegration(self.client, integration_data) - for integration_data in res['iamIntegrations'] - ] - - def get_default_iam_integration(self) -> Optional["IAMIntegration"]: - """ - Returns the default IAM integration for the organization. - Will return None if there are no default integrations for the org. - """ - integrations = self.get_iam_integrations() - default_integration = [ - integration for integration in integrations - if integration.is_org_default - ] - if len(default_integration) > 1: - raise ValueError( - "Found more than one default signer. Please contact Labelbox to resolve" - ) - return None if not len( - default_integration) else default_integration.pop() - ----- -labelbox/schema/confidence_presence_checker.py -from typing import Any, Dict, List, Set - - -class LabelsConfidencePresenceChecker: - """ - Checks if a given list of labels contains at least one confidence score - """ - - @classmethod - def check(cls, raw_labels: List[Dict[str, Any]]): - keys: Set[str] = set([]) - cls._collect_keys_from_list(raw_labels, keys) - return len(keys.intersection(set(["confidence"]))) == 1 - - @classmethod - def _collect_keys_from_list(cls, objects: List[Dict[str, Any]], - keys: Set[str]): - for obj in objects: - if isinstance(obj, (list, tuple)): - cls._collect_keys_from_list(obj, keys) - elif isinstance(obj, dict): - cls._collect_keys_from_object(obj, keys) - - @classmethod - def _collect_keys_from_object(cls, object: Dict[str, Any], keys: Set[str]): - for key in object: - keys.add(key) - if isinstance(object[key], dict): - cls._collect_keys_from_object(object[key], keys) - if isinstance(object[key], (list, tuple)): - cls._collect_keys_from_list(object[key], keys) - ----- -labelbox/schema/batch.py -from typing import Generator, TYPE_CHECKING - -from labelbox.orm.db_object import DbObject, experimental -from labelbox.orm import query -from labelbox.orm.model import Entity, Field, Relationship -from labelbox.exceptions import LabelboxError, ResourceNotFoundError -from io import StringIO -from labelbox import parser -import requests -import logging -import time -import warnings - -if TYPE_CHECKING: - from labelbox import Project - -logger = logging.getLogger(__name__) - - -class Batch(DbObject): - """ A Batch is a group of data rows submitted to a project for labeling - - Attributes: - name (str) - created_at (datetime) - updated_at (datetime) - deleted (bool) - - project (Relationship): `ToOne` relationship to Project - created_by (Relationship): `ToOne` relationship to User - - """ - name = Field.String("name") - created_at = Field.DateTime("created_at") - updated_at = Field.DateTime("updated_at") - size = Field.Int("size") - consensus_settings = Field.Json("consensus_settings_json") - - # Relationships - created_by = Relationship.ToOne("User") - - def __init__(self, - client, - project_id, - *args, - failed_data_row_ids=[], - **kwargs): - super().__init__(client, *args, **kwargs) - self.project_id = project_id - self._failed_data_row_ids = failed_data_row_ids - - def project(self) -> 'Project': # type: ignore - """ Returns Project which this Batch belongs to - - Raises: - LabelboxError: if the project is not found - """ - query_str = """query getProjectPyApi($projectId: ID!) 
{ - project( - where: {id: $projectId}){ - %s - }}""" % query.results_query_part(Entity.Project) - params = {"projectId": self.project_id} - response = self.client.execute(query_str, params) - - if response is None: - raise ResourceNotFoundError(Entity.Project, params) - - return Entity.Project(self.client, response["project"]) - - def remove_queued_data_rows(self) -> None: - """ Removes remaining queued data rows from the batch and labeling queue. - """ - - project_id_param = "projectId" - batch_id_param = "batchId" - self.client.execute( - """mutation RemoveQueuedDataRowsFromBatchPyApi($%s: ID!, $%s: ID!) { - project(where: {id: $%s}) { removeQueuedDataRowsFromBatch(batchId: $%s) { id } } - }""" % (project_id_param, batch_id_param, project_id_param, - batch_id_param), { - project_id_param: self.project_id, - batch_id_param: self.uid - }, - experimental=True) - - def export_data_rows(self, - timeout_seconds=120, - include_metadata: bool = False) -> Generator: - """ Returns a generator that produces all data rows that are currently - in this batch. - - Note: For efficiency, the data are cached for 30 minutes. Newly created data rows will not appear - until the end of the cache period. - - Args: - timeout_seconds (float): Max waiting time, in seconds. - include_metadata (bool): True to return related DataRow metadata - Returns: - Generator that yields DataRow objects belonging to this batch. - Raises: - LabelboxError: if the export fails or is unable to download within the specified time. - """ - warnings.warn( - "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) - - id_param = "batchId" - metadata_param = "includeMetadataInput" - query_str = """mutation GetBatchDataRowsExportUrlPyApi($%s: ID!, $%s: Boolean!) - {exportBatchDataRows(data:{batchId: $%s , includeMetadataInput: $%s}) {downloadUrl createdAt status}} - """ % (id_param, metadata_param, id_param, metadata_param) - sleep_time = 2 - original_timeout = timeout_seconds # preserved so the timeout error reports the caller-supplied value - while True: - res = self.client.execute(query_str, { - id_param: self.uid, - metadata_param: include_metadata - }) - res = res["exportBatchDataRows"] - if res["status"] == "COMPLETE": - download_url = res["downloadUrl"] - response = requests.get(download_url) - response.raise_for_status() - reader = parser.reader(StringIO(response.text)) - return ( - Entity.DataRow(self.client, result) for result in reader) - elif res["status"] == "FAILED": - raise LabelboxError("Data row export failed.") - - timeout_seconds -= sleep_time - if timeout_seconds <= 0: - raise LabelboxError( - f"Unable to export data rows within {original_timeout} seconds." - ) - - logger.debug("Batch '%s' data row export, waiting for server...", - self.uid) - time.sleep(sleep_time) - - def delete(self) -> None: - """ Deletes the given batch. - - Note: Batch deletion for batches that have labels is forbidden. - """ - - project_id_param = "projectId" - batch_id_param = "batchId" - self.client.execute("""mutation DeleteBatchPyApi($%s: ID!, $%s: ID!) 
{ - project(where: {id: $%s}) { deleteBatch(batchId: $%s) { deletedBatchId } } - }""" % (project_id_param, batch_id_param, project_id_param, - batch_id_param), { - project_id_param: self.project_id, - batch_id_param: self.uid - }, - experimental=True) - - def delete_labels(self, set_labels_as_template=False) -> None: - """ Deletes labels that were created for data rows in the batch. - - Args: - set_labels_as_template (bool): When set to true, the deleted labels will be kept as templates. - """ - - project_id_param = "projectId" - batch_id_param = "batchId" - type_param = "type" - res = self.client.execute( - """mutation DeleteBatchLabelsPyApi($%s: ID!, $%s: ID!, $%s: DeleteBatchLabelsType!) { - project(where: {id: $%s}) { deleteBatchLabels(batchId: $%s, data:{ type: $%s }) { deletedLabelIds } } - }""" % (project_id_param, batch_id_param, type_param, project_id_param, - batch_id_param, type_param), { - project_id_param: - self.project_id, - batch_id_param: - self.uid, - type_param: - "RequeueDataWithLabelAsTemplate" - if set_labels_as_template else "RequeueData" - }, - experimental=True) - return res - - @property - def failed_data_row_ids(self): - if self._failed_data_row_ids is None: - self._failed_data_row_ids = [] - - return (x for x in self._failed_data_row_ids) - ----- -labelbox/schema/invite.py -from dataclasses import dataclass - -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field -from labelbox.schema.role import ProjectRole, format_role - - -@dataclass -class InviteLimit: - """ - remaining (int): Number of invites remaining in the org - used (int): Number of invites used in the org - limit (int): Maximum number of invites available to the org - """ - - remaining: int - used: int - limit: int - - -class Invite(DbObject): - """ - An object representing a user invite - """ - created_at = Field.DateTime("created_at") - organization_role_name = Field.String("organization_role_name") - email = Field.String("email", "inviteeEmail") - - def __init__(self, client, invite_response): - project_roles = invite_response.pop("projectInvites", []) - super().__init__(client, invite_response) - - self.project_roles = [ - ProjectRole(project=client.get_project(r['projectId']), - role=client.get_roles()[format_role( - r['projectRoleName'])]) for r in project_roles - ] - ----- -labelbox/schema/__init__.py -import labelbox.schema.asset_attachment -import labelbox.schema.bulk_import_request -import labelbox.schema.annotation_import -import labelbox.schema.benchmark -import labelbox.schema.data_row -import labelbox.schema.dataset -import labelbox.schema.invite -import labelbox.schema.label -import labelbox.schema.labeling_frontend -import labelbox.schema.model -import labelbox.schema.model_run -import labelbox.schema.ontology -import labelbox.schema.organization -import labelbox.schema.project -import labelbox.schema.review -import labelbox.schema.role -import labelbox.schema.task -import labelbox.schema.user -import labelbox.schema.webhook -import labelbox.schema.data_row_metadata -import labelbox.schema.batch -import labelbox.schema.iam_integration -import labelbox.schema.media_type -import labelbox.schema.identifiables -import labelbox.schema.identifiable -import labelbox.schema.catalog - ----- -labelbox/schema/asset_attachment.py -import warnings -from enum import Enum -from typing import Dict - -from labelbox.orm.db_object import 
DbObject -from labelbox.orm.model import Field - - -class AssetAttachment(DbObject): - """Asset attachment provides extra context about an asset while labeling. - - Attributes: - attachment_type (str): IMAGE, VIDEO, IMAGE_OVERLAY, HTML, RAW_TEXT, TEXT_URL, or PDF_URL. TEXT attachment type is deprecated. - attachment_value (str): URL to an external file or a string of text - """ - - class AttachmentType(Enum): - - @classmethod - def __missing__(cls, value: object): - if str(value) == "TEXT": - warnings.warn( - "The TEXT attachment type is deprecated. Use RAW_TEXT instead." - ) - return cls.RAW_TEXT - return value - - VIDEO = "VIDEO" - IMAGE = "IMAGE" - # TEXT = "TEXT" # Deprecated - IMAGE_OVERLAY = "IMAGE_OVERLAY" - HTML = "HTML" - RAW_TEXT = "RAW_TEXT" - TEXT_URL = "TEXT_URL" - PDF_URL = "PDF_URL" - CAMERA_IMAGE = "CAMERA_IMAGE" # Used by experimental point-cloud editor - - for topic in AttachmentType: - vars()[topic.name] = topic.value - - attachment_type = Field.String("attachment_type", "type") - attachment_value = Field.String("attachment_value", "value") - - @classmethod - def validate_attachment_json(cls, attachment_json: Dict[str, str]) -> None: - for required_key in ['type', 'value']: - if required_key not in attachment_json: - raise ValueError( - f"Must provide a `{required_key}` key for each attachment. Found {attachment_json}." - ) - cls.validate_attachment_type(attachment_json['type']) - - @classmethod - def validate_attachment_type(cls, attachment_type: str) -> None: - valid_types = set(cls.AttachmentType.__members__) - if attachment_type not in valid_types: - raise ValueError( - f"meta_type must be one of {valid_types}. Found {attachment_type}" - ) - - def delete(self) -> None: - """Deletes an attachment on the data row.""" - query_str = """mutation deleteDataRowAttachmentPyApi($attachment_id: ID!) { - deleteDataRowAttachment(where: {id: $attachment_id}) { - id} - }""" - self.client.execute(query_str, {"attachment_id": self.uid}) - ----- -labelbox/schema/export_filters.py -import sys - -from datetime import datetime, timezone -from typing import Collection, Dict, Tuple, List, Optional -from labelbox.typing_imports import Literal -if sys.version_info >= (3, 8): - from typing import TypedDict -else: - from typing_extensions import TypedDict - -SEARCH_LIMIT_PER_EXPORT_V2 = 2_000 -ISO_8061_FORMAT = "%Y-%m-%dT%H:%M:%S%z" - - -class BaseExportFilters(TypedDict): - data_row_ids: Optional[List[str]] - """ Datarow ids to export - Please refer to https://docs.labelbox.com/docs/limits#export on the allowed limit of data_row_ids - Example: - >>> ["clgo3lyax0000veeezdbu3ws4", "clgo3lzjl0001veeer6y6z8zp", ...] - - """ - - global_keys: Optional[List[str]] - """ Global keys to export - Please refer to https://docs.labelbox.com/docs/limits#export on the allowed limit of data_row_ids - Example: - >>> ["key1", "key2", ...] 
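Note that a filters dict may set either data_row_ids or global_keys, but not both; see validate_one_of_data_row_ids_or_global_keys below. Illustrative example (hypothetical keys): - >>> filters = {"global_keys": ["key1", "key2"]}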
- """ - - -class SharedExportFilters(BaseExportFilters): - label_created_at: Optional[Tuple[str, str]] - """ Date range for labels created at - Formatted "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss" - Examples: - >>> ["2000-01-01 00:00:00", "2050-01-01 00:00:00"] - >>> [None, "2050-01-01 00:00:00"] - >>> ["2000-01-01 00:00:00", None] - """ - last_activity_at: Optional[Tuple[str, str]] - - -class ProjectExportFilters(SharedExportFilters): - batch_ids: Optional[List[str]] - """ Batch ids to export - Example: - >>> ["clgo3lyax0000veeezdbu3ws4"] - """ - workflow_status: Optional[Literal["ToLabel", "InReview", "InRework", - "Done"]] - """ Export data rows matching workflow status - Example: - >>> "InReview" - """ - - -class DatasetExportFilters(SharedExportFilters): - pass - - -class CatalogExportFilters(SharedExportFilters): - pass - - -class DatarowExportFilters(BaseExportFilters): - pass - - -def validate_datetime(datetime_str: str) -> bool: - """helper function to validate that datetime's format: "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss" - or ISO 8061 format "YYYY-MM-DDThh:mm:ss±hhmm" (Example: "2023-05-23T14:30:00+0530")""" - if datetime_str: - for fmt in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S", ISO_8061_FORMAT): - try: - datetime.strptime(datetime_str, fmt) - return True - except ValueError: - pass - raise ValueError(f"""Incorrect format for: {datetime_str}. - Format must be \"YYYY-MM-DD\" or \"YYYY-MM-DD hh:mm:ss\" or ISO 8061 format \"YYYY-MM-DDThh:mm:ss±hhmm\"""" - ) - return True - - -def convert_to_utc_if_iso8061(datetime_str: str, timezone_str: Optional[str]): - """helper function to convert datetime to UTC if it is in ISO_8061_FORMAT and set timezone to UTC""" - try: - date_obj = datetime.strptime(datetime_str, ISO_8061_FORMAT) - date_obj_utc = date_obj.replace(tzinfo=timezone.utc) - datetime_str = date_obj_utc.strftime(ISO_8061_FORMAT) - timezone_str = "UTC" - except ValueError: - pass - return datetime_str, timezone_str - - -def validate_one_of_data_row_ids_or_global_keys(filters): - if filters.get("data_row_ids") is not None and filters.get( - "global_keys") is not None: - raise ValueError( - "data_rows and global_keys cannot both be present in export filters" - ) - - -def validate_at_least_one_of_data_row_ids_or_global_keys(filters): - if not filters.get("data_row_ids") and not filters.get("global_keys"): - raise ValueError("data_rows and global_keys cannot both be empty") - - -def build_filters(client, filters): - search_query: List[Dict[str, Collection[str]]] = [] - timezone: Optional[str] = None - - def _get_timezone() -> str: - timezone_query_str = """query CurrentUserPyApi { user { timezone } }""" - tz_res = client.execute(timezone_query_str) - return tz_res["user"]["timezone"] or "UTC" - - def _build_id_filters(ids: list, - type_name: str, - search_where_limit: int = SEARCH_LIMIT_PER_EXPORT_V2): - if not isinstance(ids, list): - raise ValueError(f"{type_name} filter expects a list.") - if len(ids) == 0: - raise ValueError(f"{type_name} filter expects a non-empty list.") - if len(ids) > search_where_limit: - raise ValueError( - f"{type_name} filter only supports a max of {search_where_limit} items." 
- ) - search_query.append({"ids": ids, "operator": "is", "type": type_name}) - - validate_one_of_data_row_ids_or_global_keys(filters) - - last_activity_at = filters.get("last_activity_at") - if last_activity_at: - timezone = _get_timezone() - start, end = last_activity_at - if (start is not None and end is not None): - [validate_datetime(date) for date in last_activity_at] - start, timezone = convert_to_utc_if_iso8061(start, timezone) - end, timezone = convert_to_utc_if_iso8061(end, timezone) - search_query.append({ - "type": "data_row_last_activity_at", - "value": { - "operator": "BETWEEN", - "timezone": timezone, - "value": { - "min": start, - "max": end - } - } - }) - elif (start is not None): - validate_datetime(start) - start, timezone = convert_to_utc_if_iso8061(start, timezone) - search_query.append({ - "type": "data_row_last_activity_at", - "value": { - "operator": "GREATER_THAN_OR_EQUAL", - "timezone": timezone, - "value": start - } - }) - elif (end is not None): - validate_datetime(end) - end, timezone = convert_to_utc_if_iso8061(end, timezone) - search_query.append({ - "type": "data_row_last_activity_at", - "value": { - "operator": "LESS_THAN_OR_EQUAL", - "timezone": timezone, - "value": end - } - }) - - label_created_at = filters.get("label_created_at") - if label_created_at: - timezone = _get_timezone() - start, end = label_created_at - if (start is not None and end is not None): - [validate_datetime(date) for date in label_created_at] - start, timezone = convert_to_utc_if_iso8061(start, timezone) - end, timezone = convert_to_utc_if_iso8061(end, timezone) - search_query.append({ - "type": "labeled_at", - "value": { - "operator": "BETWEEN", - "timezone": timezone, - "value": { - "min": start, - "max": end - } - } - }) - elif (start is not None): - validate_datetime(start) - start, timezone = convert_to_utc_if_iso8061(start, timezone) - search_query.append({ - "type": "labeled_at", - "value": { - "operator": "GREATER_THAN_OR_EQUAL", - "timezone": timezone, - "value": start - } - }) - elif (end is not None): - validate_datetime(end) - end, timezone = convert_to_utc_if_iso8061(end, timezone) - search_query.append({ - "type": "labeled_at", - "value": { - "operator": "LESS_THAN_OR_EQUAL", - "timezone": timezone, - "value": end - } - }) - - data_row_ids = filters.get("data_row_ids") - if data_row_ids is not None: - _build_id_filters(data_row_ids, "data_row_id") - - global_keys = filters.get("global_keys") - if global_keys is not None: - _build_id_filters(global_keys, "global_key") - - batch_ids = filters.get("batch_ids") - if batch_ids is not None: - _build_id_filters(batch_ids, "batch") - - workflow_status = filters.get("workflow_status") - if workflow_status: - if not isinstance(workflow_status, str): - raise ValueError("`workflow_status` filter expects a string.") - elif workflow_status not in ["ToLabel", "InReview", "InRework", "Done"]: - raise ValueError( - "`workflow_status` filter expects one of 'ToLabel', 'InReview', 'InRework', or 'Done'." 
- ) - - if workflow_status == "ToLabel": - search_query.append({"type": "task_queue_not_exist"}) - else: - search_query.append({ - "type": 'task_queue_status', - "status": workflow_status - }) - - return search_query - ----- -labelbox/schema/role.py -from dataclasses import dataclass -from typing import Dict, Optional, TYPE_CHECKING - -from labelbox.orm.model import Field, Entity -from labelbox.orm.db_object import DbObject - -if TYPE_CHECKING: - from labelbox import Client, Project - -_ROLES: Optional[Dict[str, "Role"]] = None - - -def get_roles(client: "Client") -> Dict[str, "Role"]: - global _ROLES - if _ROLES is None: - query_str = """query GetAvailableUserRolesPyApi { roles { id name } }""" - res = client.execute(query_str) - _ROLES = {} - for role in res['roles']: - role['name'] = format_role(role['name']) - _ROLES[role['name']] = Role(client, role) - return _ROLES - - -def format_role(name: str): - return name.upper().replace(' ', '_') - - -class Role(DbObject): - name = Field.String("name") - - -class OrgRole(Role): - ... - - -class UserRole(Role): - ... - - -@dataclass -class ProjectRole: - project: "Project" - role: Role - ----- -labelbox/schema/export_params.py -import sys - -from typing import Optional, List - -EXPORT_LIMIT = 30 - -from labelbox.schema.media_type import MediaType -if sys.version_info >= (3, 8): - from typing import TypedDict -else: - from typing_extensions import TypedDict - - -class DataRowParams(TypedDict): - data_row_details: Optional[bool] - metadata_fields: Optional[bool] - attachments: Optional[bool] - media_type_override: Optional[MediaType] - - -class ProjectExportParams(DataRowParams): - project_details: Optional[bool] - label_details: Optional[bool] - performance_details: Optional[bool] - interpolated_frames: Optional[bool] - - -class CatalogExportParams(DataRowParams): - project_details: Optional[bool] - label_details: Optional[bool] - performance_details: Optional[bool] - model_run_ids: Optional[List[str]] - project_ids: Optional[List[str]] - interpolated_frames: Optional[bool] - all_projects: Optional[bool] - all_model_runs: Optional[bool] - - -class ModelRunExportParams(DataRowParams): - predictions: Optional[bool] - model_run_details: Optional[bool] - - -def _validate_array_length(array, max_length, array_name): - if len(array) > max_length: - raise ValueError(f"{array_name} cannot exceed {max_length} items") - - -def validate_catalog_export_params(params: CatalogExportParams): - if "model_run_ids" in params and params["model_run_ids"] is not None: - _validate_array_length(params["model_run_ids"], EXPORT_LIMIT, - "model_run_ids") - - if "project_ids" in params and params["project_ids"] is not None: - _validate_array_length(params["project_ids"], EXPORT_LIMIT, - "project_ids") - ----- -labelbox/schema/review.py -from enum import Enum, auto - -from labelbox.orm.db_object import DbObject, Updateable, Deletable -from labelbox.orm.model import Field, Relationship - - -class Review(DbObject, Deletable, Updateable): - """ Reviewing labeled data is a collaborative quality assurance technique. - - A Review object indicates the quality of the assigned Label. The aggregated - review numbers can be obtained on a Project object. 
- - Attributes: - created_at (datetime) - updated_at (datetime) - score (float) - - created_by (Relationship): `ToOne` relationship to User - organization (Relationship): `ToOne` relationship to Organization - project (Relationship): `ToOne` relationship to Project - label (Relationship): `ToOne` relationship to Label - """ - - class NetScore(Enum): - """ Negative, Zero, or Positive. - """ - Negative = auto() - Zero = auto() - Positive = auto() - - updated_at = Field.DateTime("updated_at") - created_at = Field.DateTime("created_at") - score = Field.Float("score") - - created_by = Relationship.ToOne("User", False, "created_by") - organization = Relationship.ToOne("Organization", False) - project = Relationship.ToOne("Project", False) - label = Relationship.ToOne("Label", False) - ----- -labelbox/schema/model.py -from typing import TYPE_CHECKING -from labelbox.orm import query -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Entity, Field, Relationship - -if TYPE_CHECKING: - from labelbox import ModelRun - - -class Model(DbObject): - """A model represents a program that has been trained and - can make predictions on new data. - Attributes: - name (str) - model_runs (Relationship): `ToMany` relationship to ModelRun - """ - - name = Field.String("name") - ontology_id = Field.String("ontology_id") - model_runs = Relationship.ToMany("ModelRun", False) - - def create_model_run(self, name, config=None) -> "ModelRun": - """ Creates a model run belonging to this model. - - Args: - name (string): The name for the model run. - config (json): Model run's training metadata config - Returns: - ModelRun, the created model run. - """ - name_param = "name" - config_param = "config" - model_id_param = "modelId" - ModelRun = Entity.ModelRun - query_str = """mutation CreateModelRunPyApi($%s: String!, $%s: Json, $%s: ID!) { - createModelRun(data: {name: $%s, trainingMetadata: $%s, modelId: $%s}) {%s}}""" % ( - name_param, config_param, model_id_param, name_param, config_param, - model_id_param, query.results_query_part(ModelRun)) - res = self.client.execute(query_str, { - name_param: name, - config_param: config, - model_id_param: self.uid - }) - return ModelRun(self.client, res["createModelRun"]) - - def delete(self) -> None: - """ Deletes specified model. - - Returns: - Query execution success. - """ - ids_param = "ids" - query_str = """mutation DeleteModelPyApi($%s: ID!) 
{ - deleteModels(where: {ids: [$%s]})}""" % (ids_param, ids_param) - self.client.execute(query_str, {ids_param: str(self.uid)}) - ----- -labelbox/schema/dataset.py -from typing import Dict, Generator, List, Optional, Union, Any -import os -import json -import logging -from collections.abc import Iterable -from string import Template -import time -import warnings - -from labelbox import parser -from itertools import islice - -from concurrent.futures import ThreadPoolExecutor, as_completed -from io import StringIO -import requests - -from labelbox import pagination -from labelbox.exceptions import InvalidQueryError, LabelboxError, ResourceNotFoundError, InvalidAttributeError -from labelbox.orm.comparison import Comparison -from labelbox.orm.db_object import DbObject, Updateable, Deletable, experimental -from labelbox.orm.model import Entity, Field, Relationship -from labelbox.orm import query -from labelbox.exceptions import MalformedQueryException -from labelbox.pagination import PaginatedCollection -from labelbox.schema.data_row import DataRow -from labelbox.schema.export_filters import DatasetExportFilters, build_filters -from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params -from labelbox.schema.export_task import ExportTask -from labelbox.schema.task import Task -from labelbox.schema.user import User - -logger = logging.getLogger(__name__) - -MAX_DATAROW_PER_API_OPERATION = 150_000 - - -class Dataset(DbObject, Updateable, Deletable): - """ A Dataset is a collection of DataRows. - - Attributes: - name (str) - description (str) - updated_at (datetime) - created_at (datetime) - row_count (int): The number of rows in the dataset. Fetch the dataset again to update since this is cached. - - created_by (Relationship): `ToOne` relationship to User - organization (Relationship): `ToOne` relationship to Organization - - """ - name = Field.String("name") - description = Field.String("description") - updated_at = Field.DateTime("updated_at") - created_at = Field.DateTime("created_at") - row_count = Field.Int("row_count") - - # Relationships - created_by = Relationship.ToOne("User", False, "created_by") - organization = Relationship.ToOne("Organization", False) - iam_integration = Relationship.ToOne("IAMIntegration", False, - "iam_integration", "signer") - - def data_rows( - self, - from_cursor: Optional[str] = None, - where: Optional[Comparison] = None, - ) -> PaginatedCollection: - """ - Custom method to paginate data_rows via cursor. - - Args: - from_cursor (str): Cursor (data row id) to start from, if none, will start from the beginning - where (Comparison): Filter to apply to data rows, built by comparing a DataRow field to a value. - example: DataRow.external_id == 'my_external_id' to get a data row with external_id = 'my_external_id' - - - NOTE: - Order of retrieval is newest data row first. - Deleted data rows are not retrieved. - Failed data rows are not retrieved. - Data rows in progress *may be* retrieved.
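Example (illustrative; assumes an existing dataset with a data row carrying this external id): - >>> data_rows = dataset.data_rows(where=DataRow.external_id == "my_external_id") # doctest: +SKIP - >>> row_ids = [data_row.uid for data_row in data_rows] # doctest: +SKIP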
- """ - - page_size = 500 # hardcode to avoid overloading the server - where_param = query.where_as_dict(Entity.DataRow, - where) if where is not None else None - - template = Template( - """query DatasetDataRowsPyApi($$id: ID!, $$from: ID, $$first: Int, $$where: DatasetDataRowWhereInput) { - datasetDataRows(id: $$id, from: $$from, first: $$first, where: $$where) - { - nodes { $datarow_selections } - pageInfo { hasNextPage startCursor } - } - } - """) - query_str = template.substitute( - datarow_selections=query.results_query_part(Entity.DataRow)) - - params = { - 'id': self.uid, - 'from': from_cursor, - 'first': page_size, - 'where': where_param, - } - - return PaginatedCollection( - client=self.client, - query=query_str, - params=params, - dereferencing=['datasetDataRows', 'nodes'], - obj_class=Entity.DataRow, - cursor_path=['datasetDataRows', 'pageInfo', 'startCursor'], - ) - - def create_data_row(self, items=None, **kwargs) -> "DataRow": - """ Creates a single DataRow belonging to this dataset. - - >>> dataset.create_data_row(row_data="http://my_site.com/photos/img_01.jpg") - - Args: - items: Dictionary containing new `DataRow` data. At a minimum, - must contain `row_data` or `DataRow.row_data`. - **kwargs: Key-value arguments containing new `DataRow` data. At a minimum, - must contain `row_data`. - - Raises: - InvalidQueryError: If both dictionary and `kwargs` are provided as inputs - InvalidQueryError: If `DataRow.row_data` field value is not provided - in `kwargs`. - InvalidAttributeError: in case the DB object type does not contain - any of the field names given in `kwargs`. - - """ - invalid_argument_error = "Argument to create_data_row() must be either a dictionary, or kwargs containing `row_data` at minimum" - - def convert_field_keys(items): - if not isinstance(items, dict): - raise InvalidQueryError(invalid_argument_error) - return { - key.name if isinstance(key, Field) else key: value - for key, value in items.items() - } - - if items is not None and len(kwargs) > 0: - raise InvalidQueryError(invalid_argument_error) - - DataRow = Entity.DataRow - args = convert_field_keys(items) if items is not None else kwargs - - if DataRow.row_data.name not in args: - raise InvalidQueryError( - "DataRow.row_data missing when creating DataRow.") - - row_data = args[DataRow.row_data.name] - - if isinstance(row_data, str) and row_data.startswith("s3:/"): - raise InvalidQueryError( - "row_data: s3 assets must start with 'https'.") - - if not isinstance(row_data, str): - # If the row data is an object, upload as a string - args[DataRow.row_data.name] = json.dumps(row_data) - elif os.path.exists(row_data): - # If row data is a local file path, upload it to server. - args[DataRow.row_data.name] = self.client.upload_file(row_data) - - # Parse metadata fields, if they are provided - if DataRow.metadata_fields.name in args: - mdo = self.client.get_data_row_metadata_ontology() - args[DataRow.metadata_fields.name] = mdo.parse_upsert_metadata( - args[DataRow.metadata_fields.name]) - - query_str = """mutation CreateDataRowPyApi( - $row_data: String!, - $metadata_fields: [DataRowCustomMetadataUpsertInput!], - $attachments: [DataRowAttachmentInput!], - $media_type : MediaType, - $external_id : String, - $global_key : String, - $dataset: ID! 
- ){ - createDataRow( - data: - { - rowData: $row_data - mediaType: $media_type - metadataFields: $metadata_fields - externalId: $external_id - globalKey: $global_key - attachments: $attachments - dataset: {connect: {id: $dataset}} - } - ) - {%s} - } - """ % query.results_query_part(Entity.DataRow) - res = self.client.execute(query_str, {**args, 'dataset': self.uid}) - return DataRow(self.client, res['createDataRow']) - - def create_data_rows_sync(self, items) -> None: - """ Synchronously bulk upload data rows. - - Use this instead of `Dataset.create_data_rows` for smaller batches of data rows that need to be uploaded quickly. - Cannot use this for uploads containing more than 1000 data rows. - Each data row is also limited to 5 attachments. - - Args: - items (iterable of (dict or str)): - See the docstring for `Dataset._create_descriptor_file` for more information. - Returns: - None. If the function doesn't raise an exception then the import was successful. - - Raises: - InvalidQueryError: If the `items` parameter does not conform to - the specification in Dataset._create_descriptor_file or if the server did not accept the - DataRow creation request (unknown reason). - InvalidAttributeError: If there are fields in `items` not valid for - a DataRow. - ValueError: When the upload parameters are invalid - """ - max_data_rows_supported = 1000 - max_attachments_per_data_row = 5 - if len(items) > max_data_rows_supported: - raise ValueError( - f"Dataset.create_data_rows_sync() supports a max of {max_data_rows_supported} data rows." - " For larger imports use the async function Dataset.create_data_rows()" - ) - descriptor_url = self._create_descriptor_file( - items, max_attachments_per_data_row=max_attachments_per_data_row) - dataset_param = "datasetId" - url_param = "jsonUrl" - query_str = """mutation AppendRowsToDatasetSyncPyApi($%s: ID!, $%s: String!){ - appendRowsToDatasetSync(data:{datasetId: $%s, jsonFileUrl: $%s} - ){dataset{id}}} """ % (dataset_param, url_param, dataset_param, - url_param) - self.client.execute(query_str, { - dataset_param: self.uid, - url_param: descriptor_url - }) - - def create_data_rows(self, items) -> "Task": - """ Asynchronously bulk upload data rows - - Use this instead of `Dataset.create_data_rows_sync` uploads for batches that contain more than 1000 data rows. - - Args: - items (iterable of (dict or str)): See the docstring for `Dataset._create_descriptor_file` for more information - - Returns: - Task representing the data import on the server side. The Task - can be used for inspecting task progress and waiting until it's done. - - Raises: - InvalidQueryError: If the `items` parameter does not conform to - the specification above or if the server did not accept the - DataRow creation request (unknown reason). - ResourceNotFoundError: If unable to retrieve the Task for the - import process. This could imply that the import failed. - InvalidAttributeError: If there are fields in `items` not valid for - a DataRow. 
-            ValueError: When the upload parameters are invalid
-        """
-        descriptor_url = self._create_descriptor_file(items)
-        # Create data source
-        dataset_param = "datasetId"
-        url_param = "jsonUrl"
-        query_str = """mutation AppendRowsToDatasetPyApi($%s: ID!, $%s: String!){
-            appendRowsToDataset(data:{datasetId: $%s, jsonFileUrl: $%s}
-            ){ taskId accepted errorMessage } } """ % (dataset_param, url_param,
-                                                       dataset_param, url_param)
-
-        res = self.client.execute(query_str, {
-            dataset_param: self.uid,
-            url_param: descriptor_url
-        })
-        res = res["appendRowsToDataset"]
-        if not res["accepted"]:
-            msg = res['errorMessage']
-            raise InvalidQueryError(
-                f"Server did not accept DataRow creation request. {msg}")
-
-        # Fetch and return the task.
-        task_id = res["taskId"]
-        user: User = self.client.get_user()
-        tasks: List[Task] = list(
-            user.created_tasks(where=Entity.Task.uid == task_id))
-        # Cache user in a private variable as the relationship can't be
-        # resolved due to server-side limitations (see Task.created_by)
-        # for more info.
-        if len(tasks) != 1:
-            raise ResourceNotFoundError(Entity.Task, task_id)
-        task: Task = tasks[0]
-        task._user = user
-        return task
-
-    def _create_descriptor_file(self, items, max_attachments_per_data_row=None):
-        """
-        This function is shared by both `Dataset.create_data_rows` and
-        `Dataset.create_data_rows_sync` to prepare the input file. The
-        user-defined input is validated, processed, and JSON-stringified.
-        Finally the JSON data is uploaded to GCS and a URI is returned. This
-        URI is then passed to the server-side mutations those two methods
-        execute to import the data rows.
-
-        Each element in `items` can be either a `str` or a `dict`. If
-        it is a `str`, then it is interpreted as a local file path. The file
-        is uploaded to Labelbox and a DataRow referencing it is created.
-
-        If an item is a `dict`, it must conform to one of the two following structures:
-
-        1. For static imagery, video, and text it should map `DataRow` field names to values.
-           At the minimum, an `item` passed as a `dict` must contain a `row_data` key and value.
-           If the value for row_data is a local file path and the path exists,
-           then the local file will be uploaded to Labelbox.
-
-        2. For tiled imagery the dict must match the import structure specified in the link below
-           https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import
-
-        >>> dataset.create_data_rows([
-        >>>     {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"},
-        >>>     {DataRow.row_data:"/path/to/file1.jpg"},
-        >>>     "path/to/file2.jpg",
-        >>>     {DataRow.row_data: {"tileLayerUrl" : "http://", ...}},
-        >>>     {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}}
-        >>> ])
-
-        For an example showing how to upload tiled data_rows see the following notebook:
-        https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb
-
-        Args:
-            items (iterable of (dict or str)): See above for details.
-            max_attachments_per_data_row (Optional[int]): Param used during attachment
-                validation to determine if the user has provided too many attachments.
-
-        Returns:
-            uri (string): A reference to the uploaded JSON data.
-
-        Raises:
-            InvalidQueryError: If the `items` parameter does not conform to
-                the specification above or if the server did not accept the
-                DataRow creation request (unknown reason).
-            InvalidAttributeError: If there are fields in `items` not valid for
-                a DataRow.
-            ValueError: When the upload parameters are invalid
-        """
-        file_upload_thread_count = 20
-        DataRow = Entity.DataRow
-        AssetAttachment = Entity.AssetAttachment
-
-        def upload_if_necessary(item):
-            row_data = item['row_data']
-            if isinstance(row_data, str) and os.path.exists(row_data):
-                item_url = self.client.upload_file(row_data)
-                item['row_data'] = item_url
-                if 'external_id' not in item:
-                    # Default `external_id` to local file name
-                    item['external_id'] = row_data
-            return item
-
-        def validate_attachments(item):
-            attachments = item.get('attachments')
-            if attachments:
-                if isinstance(attachments, list):
-                    if max_attachments_per_data_row and len(
-                            attachments) > max_attachments_per_data_row:
-                        raise ValueError(
-                            f"Maximum number of supported attachments per data row is {max_attachments_per_data_row}."
-                            f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary."
-                        )
-                    for attachment in attachments:
-                        AssetAttachment.validate_attachment_json(attachment)
-                else:
-                    raise ValueError(
-                        f"Attachments must be a list. Found {type(attachments)}"
-                    )
-            return attachments
-
-        def validate_conversational_data(conversational_data: list) -> None:
-            """
-            Checks each conversational message for the keys expected as per
-            https://docs.labelbox.com/reference/text-conversational#sample-conversational-json
-
-            Args:
-                conversational_data (list): list of dictionaries.
-            """
-
-            def check_message_keys(message):
-                accepted_message_keys = set([
-                    "messageId", "timestampUsec", "content", "user", "align",
-                    "canLabel"
-                ])
-                for key in message.keys():
-                    if key not in accepted_message_keys:
-                        raise KeyError(
-                            f"Invalid key {key} found! Accepted keys for messages are {accepted_message_keys}"
-                        )
-
-            if conversational_data and not isinstance(conversational_data,
-                                                      list):
-                raise ValueError(
-                    f"conversationalData must be a list. Found {type(conversational_data)}"
-                )
-
-            for message in conversational_data:
-                check_message_keys(message)
-
-        def parse_metadata_fields(item):
-            metadata_fields = item.get('metadata_fields')
-            if metadata_fields:
-                mdo = self.client.get_data_row_metadata_ontology()
-                item['metadata_fields'] = mdo.parse_upsert_metadata(
-                    metadata_fields)
-
-        def format_row(item):
-            # Formats user input into a consistent dict structure
-            if isinstance(item, dict):
-                # Convert fields to strings
-                item = {
-                    key.name if isinstance(key, Field) else key: value
-                    for key, value in item.items()
-                }
-            elif isinstance(item, str):
-                # The main advantage of using a string over a dict is that the user is specifying
-                # that the file should exist locally.
-                # That info is lost after this section so we should check for it here.
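-                # (A nonexistent path below is treated as a user error rather
-                # than imported verbatim, so typos in local file paths fail
-                # fast with a clear message.)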
-                if not os.path.exists(item):
-                    raise ValueError(f"Filepath {item} does not exist.")
-                item = {"row_data": item, "external_id": item}
-            return item
-
-        def validate_keys(item):
-            if 'row_data' not in item:
-                raise InvalidQueryError(
-                    "`row_data` missing when creating DataRow.")
-
-            if isinstance(item.get('row_data'),
-                          str) and item.get('row_data').startswith("s3:/"):
-                raise InvalidQueryError(
-                    "row_data: s3 assets must start with 'https'.")
-            invalid_keys = set(item) - {
-                *{f.name for f in DataRow.fields()}, 'attachments', 'media_type'
-            }
-            if invalid_keys:
-                raise InvalidAttributeError(DataRow, invalid_keys)
-            return item
-
-        def format_legacy_conversational_data(item):
-            messages = item.pop("conversationalData")
-            version = item.pop("version", 1)
-            conversation_type = item.pop(
-                "type", "application/vnd.labelbox.conversational")
-            # Normalize legacy camelCase keys to the snake_case field names
-            # expected by `validate_keys`.
-            if "externalId" in item:
-                external_id = item.pop("externalId")
-                item["external_id"] = external_id
-            if "globalKey" in item:
-                global_key = item.pop("globalKey")
-                item["global_key"] = global_key
-            validate_conversational_data(messages)
-            one_conversation = {
-                "type": conversation_type,
-                "version": version,
-                "messages": messages
-            }
-            item["row_data"] = one_conversation
-            return item
-
-        def convert_item(item):
-            if "tileLayerUrl" in item:
-                validate_attachments(item)
-                return item
-
-            if "conversationalData" in item:
-                format_legacy_conversational_data(item)
-
-            # Convert all payload variations into the same dict format
-            item = format_row(item)
-            # Make sure required keys exist (and there are no extra keys)
-            validate_keys(item)
-            # Make sure attachments are valid
-            validate_attachments(item)
-            # Parse metadata fields if they exist
-            parse_metadata_fields(item)
-            # Upload any local file paths
-            item = upload_if_necessary(item)
-            return item
-
-        if not isinstance(items, Iterable):
-            raise ValueError(
-                f"Must pass an iterable to create_data_rows. Found {type(items)}"
-            )
-
-        if len(items) > MAX_DATAROW_PER_API_OPERATION:
-            raise MalformedQueryException(
-                f"Cannot create more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call."
-            )
-
-        with ThreadPoolExecutor(file_upload_thread_count) as executor:
-            futures = [executor.submit(convert_item, item) for item in items]
-            items = [future.result() for future in as_completed(futures)]
-        # Prepare and upload the descriptor file
-        data = json.dumps(items)
-        return self.client.upload_data(data,
-                                       content_type="application/json",
-                                       filename="json_import.json")
-
-    def data_rows_for_external_id(self,
-                                  external_id,
-                                  limit=10) -> List["DataRow"]:
-        """ Convenience method for getting multiple `DataRow`s belonging to this
-        `Dataset` that have the given `external_id`.
-
-        Args:
-            external_id (str): External ID of the sought `DataRow`s.
-            limit (int): The maximum number of data rows to return for the given external_id
-
-        Returns:
-            A list of `DataRow`s with the given external ID.
-
-        Raises:
-            labelbox.exceptions.ResourceNotFoundError: If there is no `DataRow`
-                in this `Dataset` with the given external ID.
-        """
-        DataRow = Entity.DataRow
-        where = DataRow.external_id == external_id
-
-        data_rows = self.data_rows(where=where)
-        # Get at most `limit` data_rows.
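-        # `islice` consumes the paginated collection lazily, so iteration
-        # stops after `limit` rows instead of paging through every match.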
- at_most_data_rows = list(islice(data_rows, limit)) - - if not len(at_most_data_rows): - raise ResourceNotFoundError(DataRow, where) - return at_most_data_rows - - def data_row_for_external_id(self, external_id) -> "DataRow": - """ Convenience method for getting a single `DataRow` belonging to this - `Dataset` that has the given `external_id`. - - Args: - external_id (str): External ID of the sought `DataRow`. - - Returns: - A single `DataRow` with the given ID. - - Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no `DataRow` - in this `DataSet` with the given external ID, or if there are - multiple `DataRows` for it. - """ - data_rows = self.data_rows_for_external_id(external_id=external_id, - limit=2) - if len(data_rows) > 1: - logger.warning( - f"More than one data_row has the provided external_id : `%s`. Use function data_rows_for_external_id to fetch all", - external_id) - return data_rows[0] - - def export_data_rows(self, - timeout_seconds=120, - include_metadata: bool = False) -> Generator: - """ Returns a generator that produces all data rows that are currently - attached to this dataset. - - Note: For efficiency, the data are cached for 30 minutes. Newly created data rows will not appear - until the end of the cache period. - - Args: - timeout_seconds (float): Max waiting time, in seconds. - include_metadata (bool): True to return related DataRow metadata - Returns: - Generator that yields DataRow objects belonging to this dataset. - Raises: - LabelboxError: if the export fails or is unable to download within the specified time. - """ - warnings.warn( - "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) - id_param = "datasetId" - metadata_param = "includeMetadataInput" - query_str = """mutation GetDatasetDataRowsExportUrlPyApi($%s: ID!, $%s: Boolean!) - {exportDatasetDataRows(data:{datasetId: $%s , includeMetadataInput: $%s}) {downloadUrl createdAt status}} - """ % (id_param, metadata_param, id_param, metadata_param) - sleep_time = 2 - while True: - res = self.client.execute(query_str, { - id_param: self.uid, - metadata_param: include_metadata - }) - res = res["exportDatasetDataRows"] - if res["status"] == "COMPLETE": - download_url = res["downloadUrl"] - response = requests.get(download_url) - response.raise_for_status() - reader = parser.reader(StringIO(response.text)) - return ( - Entity.DataRow(self.client, result) for result in reader) - elif res["status"] == "FAILED": - raise LabelboxError("Data row export failed.") - - timeout_seconds -= sleep_time - if timeout_seconds <= 0: - raise LabelboxError( - f"Unable to export data rows within {timeout_seconds} seconds." - ) - - logger.debug("Dataset '%s' data row export, waiting for server...", - self.uid) - time.sleep(sleep_time) - - @experimental - def export( - self, - task_name: Optional[str] = None, - filters: Optional[DatasetExportFilters] = None, - params: Optional[CatalogExportParams] = None, - ) -> ExportTask: - """ - Creates a dataset export task with the given params and returns the task. - - >>> dataset = client.get_dataset(DATASET_ID) - >>> task = dataset.export( - >>> filters={ - >>> "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - >>> "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - >>> "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] 
# or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...] - >>> }, - >>> params={ - >>> "performance_details": False, - >>> "label_details": True - >>> }) - >>> task.wait_till_done() - >>> task.result - """ - task = self._export(task_name, filters, params, streamable=True) - return ExportTask(task) - - def export_v2( - self, - task_name: Optional[str] = None, - filters: Optional[DatasetExportFilters] = None, - params: Optional[CatalogExportParams] = None, - ) -> Task: - """ - Creates a dataset export task with the given params and returns the task. - - >>> dataset = client.get_dataset(DATASET_ID) - >>> task = dataset.export_v2( - >>> filters={ - >>> "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - >>> "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - >>> "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] # or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...] - >>> }, - >>> params={ - >>> "performance_details": False, - >>> "label_details": True - >>> }) - >>> task.wait_till_done() - >>> task.result - """ - return self._export(task_name, filters, params) - - def _export( - self, - task_name: Optional[str] = None, - filters: Optional[DatasetExportFilters] = None, - params: Optional[CatalogExportParams] = None, - streamable: bool = False, - ) -> Task: - _params = params or CatalogExportParams({ - "attachments": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "model_run_ids": None, - "project_ids": None, - "interpolated_frames": False, - "all_projects": False, - "all_model_runs": False, - }) - validate_catalog_export_params(_params) - - _filters = filters or DatasetExportFilters({ - "last_activity_at": None, - "label_created_at": None, - "data_row_ids": None, - "global_keys": None, - }) - - mutation_name = "exportDataRowsInCatalog" - create_task_query_str = ( - f"mutation {mutation_name}PyApi" - f"($input: ExportDataRowsInCatalogInput!)" - f"{{{mutation_name}(input: $input){{taskId}}}}") - media_type_override = _params.get('media_type_override', None) - - if task_name is None: - task_name = f"Export v2: dataset - {self.name}" - query_params: Dict[str, Any] = { - "input": { - "taskName": task_name, - "filters": { - "searchQuery": { - "scope": None, - "query": None, - } - }, - "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), - "projectIds": - _params.get('project_ids', None), - "modelRunIds": - _params.get('model_run_ids', None), - "allProjects": - _params.get('all_projects', False), - "allModelRuns": - _params.get('all_model_runs', False), - }, - "streamable": streamable, - } - } - - search_query = build_filters(self.client, _filters) - search_query.append({ - "ids": [self.uid], - "operator": "is", - "type": "dataset" - }) - - query_params["input"]["filters"]["searchQuery"]["query"] = search_query - - res = self.client.execute(create_task_query_str, - query_params, - 
error_log_key="errors") - res = res[mutation_name] - task_id = res["taskId"] - return Task.get_task(self.client, task_id) - ----- -labelbox/schema/consensus_settings.py -from labelbox.utils import _CamelCaseMixin - - -class ConsensusSettings(_CamelCaseMixin): - """Container for holding consensus quality settings - - >>> ConsensusSettings( - >>> number_of_labels = 2, - >>> coverage_percentage = 0.2 - >>> ) - - Args: - number_of_labels: Number of labels for consensus - coverage_percentage: Percentage of data rows to be labeled more than once - """ - - number_of_labels: int - coverage_percentage: float - ----- -labelbox/schema/labeling_frontend.py -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field, Relationship - - -class LabelingFrontend(DbObject): - """ Label editor. - - Represents an HTML / JavaScript UI that is used to generate - labels. “Editor” is the default Labeling Frontend that comes in every - organization. You can create new labeling frontends for an organization. - - Attributes: - name (str) - description (str) - iframe_url_path (str) - - projects (Relationship): `ToMany` relationship to Project - """ - name = Field.String("name") - description = Field.String("description") - iframe_url_path = Field.String("iframe_url_path") - - -class LabelingFrontendOptions(DbObject): - """ Label interface options. - - Attributes: - customization_options (str) - - project (Relationship): `ToOne` relationship to Project - labeling_frontend (Relationship): `ToOne` relationship to LabelingFrontend - organization (Relationship): `ToOne` relationship to Organization - """ - customization_options = Field.String("customization_options") - - project = Relationship.ToOne("Project") - labeling_frontend = Relationship.ToOne("LabelingFrontend") - organization = Relationship.ToOne("Organization") - ----- -labelbox/schema/data_row.py -import logging -from typing import TYPE_CHECKING, List, Optional, Union -import json - -from labelbox.orm import query -from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable, experimental -from labelbox.orm.model import Entity, Field, Relationship -from labelbox.schema.data_row_metadata import DataRowMetadataField # type: ignore -from labelbox.schema.export_filters import DatarowExportFilters, build_filters, validate_at_least_one_of_data_row_ids_or_global_keys -from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params -from labelbox.schema.export_task import ExportTask -from labelbox.schema.task import Task - -if TYPE_CHECKING: - from labelbox import AssetAttachment, Client - -logger = logging.getLogger(__name__) - - -class DataRow(DbObject, Updateable, BulkDeletable): - """ Internal Labelbox representation of a single piece of data (e.g. image, video, text). - - Attributes: - external_id (str): User-generated file name or identifier - global_key (str): User-generated globally unique identifier - row_data (str): Paths to local files are uploaded to Labelbox's server. - Otherwise, it's treated as an external URL. - updated_at (datetime) - created_at (datetime) - media_attributes (dict): generated media attributes for the data row - metadata_fields (list): metadata associated with the data row - metadata (list): metadata associated with the data row as list of DataRowMetadataField. 
- When importing Data Rows with metadata, use `metadata_fields` instead - - dataset (Relationship): `ToOne` relationship to Dataset - created_by (Relationship): `ToOne` relationship to User - organization (Relationship): `ToOne` relationship to Organization - labels (Relationship): `ToMany` relationship to Label - attachments (Relationship) `ToMany` relationship with AssetAttachment - """ - external_id = Field.String("external_id") - global_key = Field.String("global_key") - row_data = Field.String("row_data") - updated_at = Field.DateTime("updated_at") - created_at = Field.DateTime("created_at") - media_attributes = Field.Json("media_attributes") - metadata_fields = Field.List( - dict, - graphql_type="DataRowCustomMetadataUpsertInput!", - name="metadata_fields", - result_subquery="metadataFields { schemaId name value kind }") - metadata = Field.List(DataRowMetadataField, - name="metadata", - graphql_name="customMetadata", - result_subquery="customMetadata { schemaId value }") - - # Relationships - dataset = Relationship.ToOne("Dataset") - created_by = Relationship.ToOne("User", False, "created_by") - organization = Relationship.ToOne("Organization", False) - labels = Relationship.ToMany("Label", True) - attachments = Relationship.ToMany("AssetAttachment", False, "attachments") - - supported_meta_types = supported_attachment_types = set( - Entity.AssetAttachment.AttachmentType.__members__) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.attachments.supports_filtering = False - self.attachments.supports_sorting = False - - def update(self, **kwargs): - # Convert row data to string if it is an object - # All other updates pass through - primary_fields = ["external_id", "global_key", "row_data"] - for field in primary_fields: - data = kwargs.get(field) - if data == "" or data == {}: - raise ValueError(f"{field} cannot be empty if it is set") - if not any(kwargs.get(field) for field in primary_fields): - raise ValueError( - f"At least one of these fields needs to be present: {primary_fields}" - ) - - row_data = kwargs.get("row_data") - if isinstance(row_data, dict): - kwargs['row_data'] = json.dumps(row_data) - super().update(**kwargs) - - @staticmethod - def bulk_delete(data_rows) -> None: - """ Deletes all the given DataRows. - - Args: - data_rows (list of DataRow): The DataRows to delete. - """ - BulkDeletable._bulk_delete(data_rows, True) - - def get_winning_label_id(self, project_id: str) -> Optional[str]: - """ Retrieves the winning label ID, i.e. the one that was marked as the - best for a particular data row, in a project's workflow. - - Args: - project_id (str): ID of the project containing the data row - """ - data_row_id_param = "dataRowId" - project_id_param = "projectId" - query_str = """query GetWinningLabelIdPyApi($%s: ID!, $%s: ID!) { - dataRow(where: { id: $%s }) { - labelingActivity(where: { projectId: $%s }) { - selectedLabelId - } - }} """ % (data_row_id_param, project_id_param, data_row_id_param, - project_id_param) - - res = self.client.execute(query_str, { - data_row_id_param: self.uid, - project_id_param: project_id, - }) - - return res["dataRow"]["labelingActivity"]["selectedLabelId"] - - def create_attachment(self, - attachment_type, - attachment_value, - attachment_name=None) -> "AssetAttachment": - """ Adds an AssetAttachment to a DataRow. - Labelers can view these attachments while labeling. 
- - >>> datarow.create_attachment("TEXT", "This is a text message") - - Args: - attachment_type (str): Asset attachment type, must be one of: - VIDEO, IMAGE, TEXT, IMAGE_OVERLAY (AssetAttachment.AttachmentType) - attachment_value (str): Asset attachment value. - attachment_name (str): (Optional) Asset attachment name. - Returns: - `AssetAttachment` DB object. - Raises: - ValueError: asset_type must be one of the supported types. - """ - Entity.AssetAttachment.validate_attachment_type(attachment_type) - - attachment_type_param = "type" - attachment_value_param = "value" - attachment_name_param = "name" - data_row_id_param = "dataRowId" - - query_str = """mutation CreateDataRowAttachmentPyApi( - $%s: AttachmentType!, $%s: String!, $%s: String, $%s: ID!) { - createDataRowAttachment(data: { - type: $%s value: $%s name: $%s dataRowId: $%s}) {%s}} """ % ( - attachment_type_param, attachment_value_param, - attachment_name_param, data_row_id_param, attachment_type_param, - attachment_value_param, attachment_name_param, data_row_id_param, - query.results_query_part(Entity.AssetAttachment)) - - res = self.client.execute( - query_str, { - attachment_type_param: attachment_type, - attachment_value_param: attachment_value, - attachment_name_param: attachment_name, - data_row_id_param: self.uid - }) - return Entity.AssetAttachment(self.client, - res["createDataRowAttachment"]) - - @experimental - @staticmethod - def export( - client: "Client", - data_rows: Optional[List[Union[str, "DataRow"]]] = None, - global_keys: Optional[List[str]] = None, - task_name: Optional[str] = None, - params: Optional[CatalogExportParams] = None, - ) -> ExportTask: - """ - Creates a data rows export task with the given list, params and returns the task. - Args: - client (Client): client to use to make the export request - data_rows (list of DataRow or str): list of data row objects or data row ids to export - task_name (str): name of remote task - params (CatalogExportParams): export params - - >>> dataset = client.get_dataset(DATASET_ID) - >>> task = DataRow.export( - >>> data_rows=[data_row.uid for data_row in dataset.data_rows.list()], - >>> # or a list of DataRow objects: data_rows = data_set.data_rows.list() - >>> # or a list of global_keys=["global_key_1", "global_key_2"], - >>> # Note that exactly one of: data_rows or global_keys parameters can be passed in at a time - >>> # and if data rows ids is present, global keys will be ignored - >>> params={ - >>> "performance_details": False, - >>> "label_details": True - >>> }) - >>> task.wait_till_done() - >>> task.result - """ - task = DataRow._export(client, - data_rows, - global_keys, - task_name, - params, - streamable=True) - return ExportTask(task) - - @staticmethod - def export_v2( - client: "Client", - data_rows: Optional[List[Union[str, "DataRow"]]] = None, - global_keys: Optional[List[str]] = None, - task_name: Optional[str] = None, - params: Optional[CatalogExportParams] = None, - ) -> Task: - """ - Creates a data rows export task with the given list, params and returns the task. 
- Args: - client (Client): client to use to make the export request - data_rows (list of DataRow or str): list of data row objects or data row ids to export - task_name (str): name of remote task - params (CatalogExportParams): export params - - - >>> dataset = client.get_dataset(DATASET_ID) - >>> task = DataRow.export_v2( - >>> data_rows=[data_row.uid for data_row in dataset.data_rows.list()], - >>> # or a list of DataRow objects: data_rows = data_set.data_rows.list() - >>> # or a list of global_keys=["global_key_1", "global_key_2"], - >>> # Note that exactly one of: data_rows or global_keys parameters can be passed in at a time - >>> # and if data rows ids is present, global keys will be ignored - >>> params={ - >>> "performance_details": False, - >>> "label_details": True - >>> }) - >>> task.wait_till_done() - >>> task.result - """ - return DataRow._export(client, data_rows, global_keys, task_name, - params) - - @staticmethod - def _export( - client: "Client", - data_rows: Optional[List[Union[str, "DataRow"]]] = None, - global_keys: Optional[List[str]] = None, - task_name: Optional[str] = None, - params: Optional[CatalogExportParams] = None, - streamable: bool = False, - ) -> Task: - _params = params or CatalogExportParams({ - "attachments": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "model_run_ids": None, - "project_ids": None, - "interpolated_frames": False, - "all_projects": False, - "all_model_runs": False, - }) - - validate_catalog_export_params(_params) - - mutation_name = "exportDataRowsInCatalog" - create_task_query_str = ( - f"mutation {mutation_name}PyApi" - f"($input: ExportDataRowsInCatalogInput!)" - f"{{{mutation_name}(input: $input){{taskId}}}}") - - data_row_ids = [] - if data_rows is not None: - for dr in data_rows: - if isinstance(dr, DataRow): - data_row_ids.append(dr.uid) - elif isinstance(dr, str): - data_row_ids.append(dr) - - filters = DatarowExportFilters({ - "data_row_ids": data_row_ids, - "global_keys": None, - }) if data_row_ids else DatarowExportFilters({ - "data_row_ids": None, - "global_keys": global_keys, - }) - validate_at_least_one_of_data_row_ids_or_global_keys(filters) - - search_query = build_filters(client, filters) - media_type_override = _params.get('media_type_override', None) - - if task_name is None: - task_name = f"Export v2: data rows {len(data_row_ids)}" - query_params = { - "input": { - "taskName": task_name, - "filters": { - "searchQuery": { - "scope": None, - "query": search_query - } - }, - "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), - "projectIds": - _params.get('project_ids', None), - "modelRunIds": - _params.get('model_run_ids', None), - "allProjects": - _params.get('all_projects', False), - "allModelRuns": - _params.get('all_model_runs', False), - }, - "streamable": streamable - } - } - - res = client.execute(create_task_query_str, - query_params, - 
error_log_key="errors") - print(res) - res = res[mutation_name] - task_id = res["taskId"] - return Task.get_task(client, task_id) - ----- -labelbox/schema/id_type.py -from enum import Enum - - -class IdType(str, Enum): - """ - The type of id used to identify a data row. - - Currently supported types are: - - DataRowId: The id assigned to a data row by Labelbox. - - GlobalKey: The id assigned to a data row by the user. - """ - DataRowId = "ID" - GlobalKey = "GKEY" - ----- -labelbox/schema/conflict_resolution_strategy.py -from enum import Enum - - -class ConflictResolutionStrategy(str, Enum): - KeepExisting = "KEEP_EXISTING" - OverrideWithAnnotations = "OVERRIDE_WITH_ANNOTATIONS" - OverrideWithPredictions = "OVERRIDE_WITH_PREDICTIONS" - - @staticmethod - def from_str(label: str) -> "ConflictResolutionStrategy": - return ConflictResolutionStrategy[label] - ----- -labelbox/schema/label.py -from typing import TYPE_CHECKING - -from labelbox.orm import query -from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable -from labelbox.orm.model import Entity, Field, Relationship - -if TYPE_CHECKING: - from labelbox import Benchmark, Review -""" Client-side object type definitions. """ - - -class Label(DbObject, Updateable, BulkDeletable): - """ Label represents an assessment on a DataRow. For example one label could - contain 100 bounding boxes (annotations). - - Attributes: - label (str) - seconds_to_label (float) - agreement (float) - benchmark_agreement (float) - is_benchmark_reference (bool) - - project (Relationship): `ToOne` relationship to Project - data_row (Relationship): `ToOne` relationship to DataRow - reviews (Relationship): `ToMany` relationship to Review - created_by (Relationship): `ToOne` relationship to User - - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.reviews.supports_filtering = False - - label = Field.String("label") - seconds_to_label = Field.Float("seconds_to_label") - agreement = Field.Float("agreement") - benchmark_agreement = Field.Float("benchmark_agreement") - is_benchmark_reference = Field.Boolean("is_benchmark_reference") - updated_at = Field.DateTime("updated_at") - created_at = Field.DateTime("created_at") - - project = Relationship.ToOne("Project") - data_row = Relationship.ToOne("DataRow") - reviews = Relationship.ToMany("Review", False) - created_by = Relationship.ToOne("User", False, "created_by") - - @staticmethod - def bulk_delete(labels) -> None: - """ Deletes all the given Labels. - - Args: - labels (list of Label): The Labels to delete. - """ - BulkDeletable._bulk_delete(labels, False) - - def create_review(self, **kwargs) -> "Review": - """ Creates a Review for this label. - - Args: - **kwargs: Review attributes. At a minimum, a `Review.score` field value must be provided. - """ - kwargs[Entity.Review.label.name] = self - kwargs[Entity.Review.project.name] = self.project() - return self.client._create(Entity.Review, kwargs) - - def create_benchmark(self) -> "Benchmark": - """ Creates a Benchmark for this Label. - - Returns: - The newly created Benchmark. - """ - label_id_param = "labelId" - query_str = """mutation CreateBenchmarkPyApi($%s: ID!) 
{ - createBenchmark(data: {labelId: $%s}) {%s}} """ % ( - label_id_param, label_id_param, - query.results_query_part(Entity.Benchmark)) - res = self.client.execute(query_str, {label_id_param: self.uid}) - return Entity.Benchmark(self.client, res["createBenchmark"]) - ----- -labelbox/schema/bulk_import_request.py -import json -import time -from uuid import UUID, uuid4 -import functools - -import logging -from pathlib import Path -from google.api_core import retry -from labelbox import parser -import requests -from labelbox import pydantic_compat -from typing_extensions import Literal -from typing import (Any, List, Optional, BinaryIO, Dict, Iterable, Tuple, Union, - Type, Set, TYPE_CHECKING) - -from labelbox import exceptions as lb_exceptions -from labelbox.orm.model import Entity -from labelbox import utils -from labelbox.orm import query -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field, Relationship -from labelbox.schema.enums import BulkImportRequestState -from labelbox.schema.serialization import serialize_labels - -if TYPE_CHECKING: - from labelbox import Project - from labelbox.types import Label - -NDJSON_MIME_TYPE = "application/x-ndjson" -logger = logging.getLogger(__name__) - - -def _make_file_name(project_id: str, name: str) -> str: - return f"{project_id}__{name}.ndjson" - - -# TODO(gszpak): move it to client.py -def _make_request_data(project_id: str, name: str, content_length: int, - file_name: str) -> dict: - query_str = """mutation createBulkImportRequestFromFilePyApi( - $projectId: ID!, $name: String!, $file: Upload!, $contentLength: Int!) { - createBulkImportRequest(data: { - projectId: $projectId, - name: $name, - filePayload: { - file: $file, - contentLength: $contentLength - } - }) { - %s - } - } - """ % query.results_query_part(BulkImportRequest) - variables = { - "projectId": project_id, - "name": name, - "file": None, - "contentLength": content_length - } - operations = json.dumps({"variables": variables, "query": query_str}) - - return { - "operations": operations, - "map": (None, json.dumps({file_name: ["variables.file"]})) - } - - -def _send_create_file_command( - client, request_data: dict, file_name: str, - file_data: Tuple[str, Union[bytes, BinaryIO], str]) -> dict: - - response = client.execute(data=request_data, files={file_name: file_data}) - - if not response.get("createBulkImportRequest", None): - raise lb_exceptions.LabelboxError( - "Failed to create BulkImportRequest, message: %s" % - response.get("errors", None) or response.get("error", None)) - - return response - - -class BulkImportRequest(DbObject): - """Represents the import job when importing annotations. 
- - Attributes: - name (str) - state (Enum): FAILED, RUNNING, or FINISHED (Refers to the whole import job) - input_file_url (str): URL to your web-hosted NDJSON file - error_file_url (str): NDJSON that contains error messages for failed annotations - status_file_url (str): NDJSON that contains status for each annotation - created_at (datetime): UTC timestamp for date BulkImportRequest was created - - project (Relationship): `ToOne` relationship to Project - created_by (Relationship): `ToOne` relationship to User - """ - name = Field.String("name") - state = Field.Enum(BulkImportRequestState, "state") - input_file_url = Field.String("input_file_url") - error_file_url = Field.String("error_file_url") - status_file_url = Field.String("status_file_url") - created_at = Field.DateTime("created_at") - - project = Relationship.ToOne("Project") - created_by = Relationship.ToOne("User", False, "created_by") - - @property - def inputs(self) -> List[Dict[str, Any]]: - """ - Inputs for each individual annotation uploaded. - This should match the ndjson annotations that you have uploaded. - - Returns: - Uploaded ndjson. - - * This information will expire after 24 hours. - """ - return self._fetch_remote_ndjson(self.input_file_url) - - @property - def errors(self) -> List[Dict[str, Any]]: - """ - Errors for each individual annotation uploaded. This is a subset of statuses - - Returns: - List of dicts containing error messages. Empty list means there were no errors - See `BulkImportRequest.statuses` for more details. - - * This information will expire after 24 hours. - """ - self.wait_until_done() - return self._fetch_remote_ndjson(self.error_file_url) - - @property - def statuses(self) -> List[Dict[str, Any]]: - """ - Status for each individual annotation uploaded. - - Returns: - A status for each annotation if the upload is done running. - See below table for more details - - .. list-table:: - :widths: 15 150 - :header-rows: 1 - - * - Field - - Description - * - uuid - - Specifies the annotation for the status row. - * - dataRow - - JSON object containing the Labelbox data row ID for the annotation. - * - status - - Indicates SUCCESS or FAILURE. - * - errors - - An array of error messages included when status is FAILURE. Each error has a name, message and optional (key might not exist) additional_info. - - * This information will expire after 24 hours. - """ - self.wait_until_done() - return self._fetch_remote_ndjson(self.status_file_url) - - @functools.lru_cache() - def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]: - """ - Fetches the remote ndjson file and caches the results. - - Args: - url (str): Can be any url pointing to an ndjson file. - Returns: - ndjson as a list of dicts. - """ - response = requests.get(url) - response.raise_for_status() - return parser.loads(response.text) - - def refresh(self) -> None: - """Synchronizes values of all fields with the database. - """ - query_str, params = query.get_single(BulkImportRequest, self.uid) - res = self.client.execute(query_str, params) - res = res[utils.camel_case(BulkImportRequest.type_name())] - self._set_field_values(res) - - def wait_until_done(self, sleep_time_seconds: int = 5) -> None: - """Blocks import job until certain conditions are met. - - Blocks until the BulkImportRequest.state changes either to - `BulkImportRequestState.FINISHED` or `BulkImportRequestState.FAILED`, - periodically refreshing object's state. 
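-
-        For example (sketch, assuming `bulk_import_request` was returned by one
-        of the `create_from_*` classmethods below):
-
-        >>> bulk_import_request.wait_until_done()
-        >>> # state is now either FINISHED or FAILED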
- - Args: - sleep_time_seconds (str): a time to block between subsequent API calls - """ - while self.state == BulkImportRequestState.RUNNING: - logger.info(f"Sleeping for {sleep_time_seconds} seconds...") - time.sleep(sleep_time_seconds) - self.__exponential_backoff_refresh() - - @retry.Retry(predicate=retry.if_exception_type(lb_exceptions.ApiLimitError, - lb_exceptions.TimeoutError, - lb_exceptions.NetworkError)) - def __exponential_backoff_refresh(self) -> None: - self.refresh() - - @classmethod - def from_name(cls, client, project_id: str, - name: str) -> 'BulkImportRequest': - """ Fetches existing BulkImportRequest. - - Args: - client (Client): a Labelbox client - project_id (str): BulkImportRequest's project id - name (str): name of BulkImportRequest - Returns: - BulkImportRequest object - - """ - query_str = """query getBulkImportRequestPyApi( - $projectId: ID!, $name: String!) { - bulkImportRequest(where: { - projectId: $projectId, - name: $name - }) { - %s - } - } - """ % query.results_query_part(cls) - params = {"projectId": project_id, "name": name} - response = client.execute(query_str, params=params) - return cls(client, response['bulkImportRequest']) - - @classmethod - def create_from_url(cls, - client, - project_id: str, - name: str, - url: str, - validate=True) -> 'BulkImportRequest': - """ - Creates a BulkImportRequest from a publicly accessible URL - to an ndjson file with predictions. - - Args: - client (Client): a Labelbox client - project_id (str): id of project for which predictions will be imported - name (str): name of BulkImportRequest - url (str): publicly accessible URL pointing to ndjson file containing predictions - validate (bool): a flag indicating if there should be a validation - if `url` is valid ndjson - Returns: - BulkImportRequest object - """ - if validate: - logger.warn( - "Validation is turned on. The file will be downloaded locally and processed before uploading." - ) - res = requests.get(url) - data = parser.loads(res.text) - _validate_ndjson(data, client.get_project(project_id)) - - query_str = """mutation createBulkImportRequestPyApi( - $projectId: ID!, $name: String!, $fileUrl: String!) { - createBulkImportRequest(data: { - projectId: $projectId, - name: $name, - fileUrl: $fileUrl - }) { - %s - } - } - """ % query.results_query_part(cls) - params = {"projectId": project_id, "name": name, "fileUrl": url} - bulk_import_request_response = client.execute(query_str, params=params) - return cls(client, - bulk_import_request_response["createBulkImportRequest"]) - - @classmethod - def create_from_objects(cls, - client, - project_id: str, - name: str, - predictions: Union[Iterable[Dict], - Iterable["Label"]], - validate=True) -> 'BulkImportRequest': - """ - Creates a `BulkImportRequest` from an iterable of dictionaries. 
- - Conforms to JSON predictions format, e.g.: - ``{ - "uuid": "9fd9a92e-2560-4e77-81d4-b2e955800092", - "schemaId": "ckappz7d700gn0zbocmqkwd9i", - "dataRow": { - "id": "ck1s02fqxm8fi0757f0e6qtdc" - }, - "bbox": { - "top": 48, - "left": 58, - "height": 865, - "width": 1512 - } - }`` - - Args: - client (Client): a Labelbox client - project_id (str): id of project for which predictions will be imported - name (str): name of BulkImportRequest - predictions (Iterable[dict]): iterable of dictionaries representing predictions - validate (bool): a flag indicating if there should be a validation - if `predictions` is valid ndjson - Returns: - BulkImportRequest object - """ - if not isinstance(predictions, list): - raise TypeError( - f"annotations must be in a form of Iterable. Found {type(predictions)}" - ) - ndjson_predictions = serialize_labels(predictions) - - if validate: - _validate_ndjson(ndjson_predictions, client.get_project(project_id)) - - data_str = parser.dumps(ndjson_predictions) - if not data_str: - raise ValueError('annotations cannot be empty') - - data = data_str.encode('utf-8') - file_name = _make_file_name(project_id, name) - request_data = _make_request_data(project_id, name, len(data_str), - file_name) - file_data = (file_name, data, NDJSON_MIME_TYPE) - response_data = _send_create_file_command(client, - request_data=request_data, - file_name=file_name, - file_data=file_data) - - return cls(client, response_data["createBulkImportRequest"]) - - @classmethod - def create_from_local_file(cls, - client, - project_id: str, - name: str, - file: Path, - validate_file=True) -> 'BulkImportRequest': - """ - Creates a BulkImportRequest from a local ndjson file with predictions. - - Args: - client (Client): a Labelbox client - project_id (str): id of project for which predictions will be imported - name (str): name of BulkImportRequest - file (Path): local ndjson file with predictions - validate_file (bool): a flag indicating if there should be a validation - if `file` is a valid ndjson file - Returns: - BulkImportRequest object - - """ - file_name = _make_file_name(project_id, name) - content_length = file.stat().st_size - request_data = _make_request_data(project_id, name, content_length, - file_name) - - with file.open('rb') as f: - if validate_file: - reader = parser.reader(f) - # ensure that the underlying json load call is valid - # https://github.com/rhgrant10/ndjson/blob/ff2f03c56b21f28f7271b27da35ca4a8bf9a05d0/ndjson/api.py#L53 - # by iterating through the file so we only store - # each line in memory rather than the entire file - try: - _validate_ndjson(reader, client.get_project(project_id)) - except ValueError: - raise ValueError(f"{file} is not a valid ndjson file") - else: - f.seek(0) - file_data = (file.name, f, NDJSON_MIME_TYPE) - response_data = _send_create_file_command(client, request_data, - file_name, file_data) - return cls(client, response_data["createBulkImportRequest"]) - - def delete(self) -> None: - """ Deletes the import job and also any annotations created by this import. - - Returns: - None - """ - id_param = "bulk_request_id" - query_str = """mutation deleteBulkImportRequestPyApi($%s: ID!) { - deleteBulkImportRequest(where: {id: $%s}) { - id - name - } - }""" % (id_param, id_param) - self.client.execute(query_str, {id_param: self.uid}) - - -def _validate_ndjson(lines: Iterable[Dict[str, Any]], - project: "Project") -> None: - """ - Client side validation of an ndjson object. 
-
-    Does not guarantee that an upload will succeed for the following reasons:
-    * We are not checking the data row types which will cause the following errors to slip through
-    * Missing frame indices will not cause an error for videos
-    * Uploaded annotations for the wrong data type will pass (e.g. entity on images)
-    * We are not checking bounds of an asset (e.g. frame index, image height, text location)
-
-    Args:
-        lines (Iterable[Dict[str,Any]]): An iterable of NDJSON lines
-        project (Project): project for which predictions will be imported
-
-    Raises:
-        MALValidationError: Raised for invalid NDJSON
-        UuidError: Duplicate UUID in upload
-    """
-    feature_schemas_by_id, feature_schemas_by_name = get_mal_schemas(
-        project.ontology())
-    uids: Set[str] = set()
-    for idx, line in enumerate(lines):
-        try:
-            annotation = NDAnnotation(**line)
-            annotation.validate_instance(feature_schemas_by_id,
-                                         feature_schemas_by_name)
-            uuid = str(annotation.uuid)
-            if uuid in uids:
-                raise lb_exceptions.UuidError(
-                    f'{uuid} already used in this import job, '
-                    'must be unique for the project.')
-            uids.add(uuid)
-        except (pydantic_compat.ValidationError, ValueError, TypeError,
-                KeyError) as e:
-            raise lb_exceptions.MALValidationError(
-                f"Invalid NDJSON on line {idx}") from e
-
-
-# The rest of this file contains objects for MAL validation
-def parse_classification(tool):
-    """
-    Parses a classification from an ontology. Only radio, checklist, and text
-    are supported for MAL.
-
-    Args:
-        tool (dict)
-
-    Returns:
-        dict
-    """
-    if tool['type'] in ['radio', 'checklist']:
-        option_schema_ids = [r['featureSchemaId'] for r in tool['options']]
-        option_names = [r['value'] for r in tool['options']]
-        return {
-            'tool': tool['type'],
-            'featureSchemaId': tool['featureSchemaId'],
-            'name': tool['name'],
-            'options': [*option_schema_ids, *option_names]
-        }
-    elif tool['type'] == 'text':
-        return {
-            'tool': tool['type'],
-            'name': tool['name'],
-            'featureSchemaId': tool['featureSchemaId']
-        }
-
-
-def get_mal_schemas(ontology):
-    """
-    Converts a project ontology to a dict for easier lookup during NDJSON validation
-
-    Args:
-        ontology (Ontology)
-    Returns:
-        Dict, Dict : Useful for looking up a tool from a given feature schema id or name
-    """
-
-    valid_feature_schemas_by_schema_id = {}
-    valid_feature_schemas_by_name = {}
-    for tool in ontology.normalized['tools']:
-        classifications = [
-            parse_classification(classification_tool)
-            for classification_tool in tool['classifications']
-        ]
-        classifications_by_schema_id = {
-            v['featureSchemaId']: v for v in classifications
-        }
-        classifications_by_name = {v['name']: v for v in classifications}
-        valid_feature_schemas_by_schema_id[tool['featureSchemaId']] = {
-            'tool': tool['tool'],
-            'classificationsBySchemaId': classifications_by_schema_id,
-            'classificationsByName': classifications_by_name,
-            'name': tool['name']
-        }
-        valid_feature_schemas_by_name[tool['name']] = {
-            'tool': tool['tool'],
-            'classificationsBySchemaId': classifications_by_schema_id,
-            'classificationsByName': classifications_by_name,
-            'name': tool['name']
-        }
-    for tool in ontology.normalized['classifications']:
-        valid_feature_schemas_by_schema_id[
-            tool['featureSchemaId']] = parse_classification(tool)
-        valid_feature_schemas_by_name[tool['name']] = parse_classification(tool)
-    return valid_feature_schemas_by_schema_id, valid_feature_schemas_by_name
-
-
-LabelboxID: str = pydantic_compat.Field(..., min_length=25, max_length=25)
-
-
-class Bbox(pydantic_compat.BaseModel):
-    top: float
-    left:
float
-    height: float
-    width: float
-
-
-class Point(pydantic_compat.BaseModel):
-    x: float
-    y: float
-
-
-class FrameLocation(pydantic_compat.BaseModel):
-    end: int
-    start: int
-
-
-class VideoSupported(pydantic_compat.BaseModel):
-    # Note that frames are only allowed as top level inferences for video
-    frames: Optional[List[FrameLocation]]
-
-
-# Base class for a special kind of union.
-# Compatible with pydantic_compat. Improves error messages over a traditional union
-class SpecialUnion:
-
-    def __new__(cls, **kwargs):
-        return cls.build(kwargs)
-
-    @classmethod
-    def __get_validators__(cls):
-        yield cls.build
-
-    @classmethod
-    def get_union_types(cls):
-        if not issubclass(cls, SpecialUnion):
-            raise TypeError(f"{cls} must be a subclass of SpecialUnion")
-
-        union_types = [x for x in cls.__orig_bases__ if hasattr(x, "__args__")]
-        if len(union_types) < 1:
-            raise TypeError(
-                f"Class {cls} should inherit from a union of objects to build")
-        if len(union_types) > 1:
-            raise TypeError(
-                f"Class {cls} should inherit from exactly one union of objects to build. Found {union_types}"
-            )
-        return union_types[0].__args__[0].__args__
-
-    @classmethod
-    def build(cls: Any, data: Union[dict,
-                                    pydantic_compat.BaseModel]) -> "NDBase":
-        """
-        Checks through all objects in the union to see which matches the input data.
-        Args:
-            data (Union[dict, pydantic_compat.BaseModel]) : The data for constructing one of the objects in the union
-        Raises:
-            KeyError: data does not contain the determinant fields for any of the types supported by this SpecialUnion
-            pydantic_compat.ValidationError: Error while trying to construct a specific object in the union
-
-        """
-        if isinstance(data, pydantic_compat.BaseModel):
-            data = data.dict()
-
-        top_level_fields = []
-        max_match = 0
-        matched = None
-
-        for type_ in cls.get_union_types():
-            determinate_fields = type_.Config.determinants(type_)
-            top_level_fields.append(determinate_fields)
-            matches = sum([val in determinate_fields for val in data])
-            if matches == len(determinate_fields) and matches > max_match:
-                max_match = matches
-                matched = type_
-
-        if matched is not None:
-            # These two have the exact same top level keys
-            if matched in [NDRadio, NDText]:
-                if isinstance(data['answer'], dict):
-                    matched = NDRadio
-                elif isinstance(data['answer'], str):
-                    matched = NDText
-                else:
-                    raise TypeError(
-                        f"Unexpected type for answer field. Found {data['answer']}. Expected a string or a dict"
-                    )
-            return matched(**data)
-        else:
-            raise KeyError(
-                f"Invalid annotation. Must have one of the following keys : {top_level_fields}. Found {data}."
-            )
-
-    @classmethod
-    def schema(cls):
-        results = {'definitions': {}}
-        for cl in cls.get_union_types():
-            schema = cl.schema()
-            results['definitions'].update(schema.pop('definitions'))
-            results[cl.__name__] = schema
-        return results
-
-
-class DataRow(pydantic_compat.BaseModel):
-    id: str
-
-
-class NDFeatureSchema(pydantic_compat.BaseModel):
-    schemaId: Optional[str] = None
-    name: Optional[str] = None
-
-    @pydantic_compat.root_validator
-    def must_set_one(cls, values):
-        if values['schemaId'] is None and values['name'] is None:
-            raise ValueError(
-                "Must set either schemaId or name for all feature schemas")
-        return values
-
-
-class NDBase(NDFeatureSchema):
-    ontology_type: str
-    uuid: UUID
-    dataRow: DataRow
-
-    def validate_feature_schemas(self, valid_feature_schemas_by_id,
-                                 valid_feature_schemas_by_name):
-        if self.name:
-            if self.name not in valid_feature_schemas_by_name:
-                raise ValueError(
-                    f"Name {self.name} is not valid for the provided project's ontology."
-                )
-
-            if self.ontology_type != valid_feature_schemas_by_name[
-                    self.name]['tool']:
-                raise ValueError(
-                    f"Name {self.name} does not map to the assigned tool {valid_feature_schemas_by_name[self.name]['tool']}"
-                )
-
-        if self.schemaId:
-            if self.schemaId not in valid_feature_schemas_by_id:
-                raise ValueError(
-                    f"Schema id {self.schemaId} is not valid for the provided project's ontology."
-                )
-
-            if self.ontology_type != valid_feature_schemas_by_id[
-                    self.schemaId]['tool']:
-                raise ValueError(
-                    f"Schema id {self.schemaId} does not map to the assigned tool {valid_feature_schemas_by_id[self.schemaId]['tool']}"
-                )
-
-    def validate_instance(self, valid_feature_schemas_by_id,
-                          valid_feature_schemas_by_name):
-        self.validate_feature_schemas(valid_feature_schemas_by_id,
-                                      valid_feature_schemas_by_name)
-
-    class Config:
-        # Users shouldn't add extra data to the payload
-        extra = 'forbid'
-
-        @staticmethod
-        def determinants(parent_cls) -> List[str]:
-            # This is a hack for better error messages
-            return [
-                k for k, v in parent_cls.__fields__.items()
-                if 'determinant' in v.field_info.extra
-            ]
-
-
-###### Classifications ######
-
-
-class NDText(NDBase):
-    ontology_type: Literal["text"] = "text"
-    answer: str = pydantic_compat.Field(determinant=True)
-    # No feature schema to check
-
-
-class NDChecklist(VideoSupported, NDBase):
-    ontology_type: Literal["checklist"] = "checklist"
-    answers: List[NDFeatureSchema] = pydantic_compat.Field(determinant=True)
-
-    @pydantic_compat.validator('answers', pre=True)
-    def validate_answers(cls, value, field):
-        # constr not working with mypy.
-        if not len(value):
-            raise ValueError("Checklist answers should not be empty")
-        return value
-
-    def validate_feature_schemas(self, valid_feature_schemas_by_id,
-                                 valid_feature_schemas_by_name):
-        # Test top level feature schema for this tool
-        super(NDChecklist,
-              self).validate_feature_schemas(valid_feature_schemas_by_id,
-                                             valid_feature_schemas_by_name)
-        # Test the feature schemas provided to the answer field
-        if len(set([answer.name or answer.schemaId for answer in self.answers
-                   ])) != len(self.answers):
-            raise ValueError(
-                f"Duplicated featureSchema found for checklist {self.uuid}")
-        for answer in self.answers:
-            options = valid_feature_schemas_by_name[
-                self.
-                name]['options'] if self.name else valid_feature_schemas_by_id[
-                    self.schemaId]['options']
-            if answer.name not in options and answer.schemaId not in options:
-                raise ValueError(
-                    f"Feature schema provided to {self.ontology_type} invalid. Expected one of {options}.
Found {answer}" - ) - - -class NDRadio(VideoSupported, NDBase): - ontology_type: Literal["radio"] = "radio" - answer: NDFeatureSchema = pydantic_compat.Field(determinant=True) - - def validate_feature_schemas(self, valid_feature_schemas_by_id, - valid_feature_schemas_by_name): - super(NDRadio, - self).validate_feature_schemas(valid_feature_schemas_by_id, - valid_feature_schemas_by_name) - options = valid_feature_schemas_by_name[ - self.name]['options'] if self.name else valid_feature_schemas_by_id[ - self.schemaId]['options'] - if self.answer.name not in options and self.answer.schemaId not in options: - raise ValueError( - f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found {self.answer.name or self.answer.schemaId}" - ) - - -#A union with custom construction logic to improve error messages -class NDClassification( - SpecialUnion, - Type[Union[ # type: ignore - NDText, NDRadio, NDChecklist]]): - ... - - -###### Tools ###### - - -class NDBaseTool(NDBase): - classifications: List[NDClassification] = [] - - #This is indepdent of our problem - def validate_feature_schemas(self, valid_feature_schemas_by_id, - valid_feature_schemas_by_name): - super(NDBaseTool, - self).validate_feature_schemas(valid_feature_schemas_by_id, - valid_feature_schemas_by_name) - for classification in self.classifications: - classification.validate_feature_schemas( - valid_feature_schemas_by_name[ - self.name]['classificationsBySchemaId'] - if self.name else valid_feature_schemas_by_id[self.schemaId] - ['classificationsBySchemaId'], valid_feature_schemas_by_name[ - self.name]['classificationsByName'] - if self.name else valid_feature_schemas_by_id[ - self.schemaId]['classificationsByName']) - - @pydantic_compat.validator('classifications', pre=True) - def validate_subclasses(cls, value, field): - #Create uuid and datarow id so we don't have to define classification objects twice - #This is caused by the fact that we require these ids for top level classifications but not for subclasses - results = [] - dummy_id = 'child'.center(25, '_') - for row in value: - results.append({ - **row, 'dataRow': { - 'id': dummy_id - }, - 'uuid': str(uuid4()) - }) - return results - - -class NDPolygon(NDBaseTool): - ontology_type: Literal["polygon"] = "polygon" - polygon: List[Point] = pydantic_compat.Field(determinant=True) - - @pydantic_compat.validator('polygon') - def is_geom_valid(cls, v): - if len(v) < 3: - raise ValueError( - f"A polygon must have at least 3 points to be valid. Found {v}") - return v - - -class NDPolyline(NDBaseTool): - ontology_type: Literal["line"] = "line" - line: List[Point] = pydantic_compat.Field(determinant=True) - - @pydantic_compat.validator('line') - def is_geom_valid(cls, v): - if len(v) < 2: - raise ValueError( - f"A line must have at least 2 points to be valid. 
Found {v}") - return v - - -class NDRectangle(NDBaseTool): - ontology_type: Literal["rectangle"] = "rectangle" - bbox: Bbox = pydantic_compat.Field(determinant=True) - #Could check if points are positive - - -class NDPoint(NDBaseTool): - ontology_type: Literal["point"] = "point" - point: Point = pydantic_compat.Field(determinant=True) - #Could check if points are positive - - -class EntityLocation(pydantic_compat.BaseModel): - start: int - end: int - - -class NDTextEntity(NDBaseTool): - ontology_type: Literal["named-entity"] = "named-entity" - location: EntityLocation = pydantic_compat.Field(determinant=True) - - @pydantic_compat.validator('location') - def is_valid_location(cls, v): - if isinstance(v, pydantic_compat.BaseModel): - v = v.dict() - - if len(v) < 2: - raise ValueError( - f"A line must have at least 2 points to be valid. Found {v}") - if v['start'] < 0: - raise ValueError(f"Text location must be positive. Found {v}") - if v['start'] > v['end']: - raise ValueError( - f"Text start location must be less or equal than end. Found {v}" - ) - return v - - -class RLEMaskFeatures(pydantic_compat.BaseModel): - counts: List[int] - size: List[int] - - @pydantic_compat.validator('counts') - def validate_counts(cls, counts): - if not all([count >= 0 for count in counts]): - raise ValueError( - "Found negative value for counts. They should all be zero or positive" - ) - return counts - - @pydantic_compat.validator('size') - def validate_size(cls, size): - if len(size) != 2: - raise ValueError( - f"Mask `size` should have two ints representing height and with. Found : {size}" - ) - if not all([count > 0 for count in size]): - raise ValueError( - f"Mask `size` should be a postitive int. Found : {size}") - return size - - -class PNGMaskFeatures(pydantic_compat.BaseModel): - # base64 encoded png bytes - png: str - - -class URIMaskFeatures(pydantic_compat.BaseModel): - instanceURI: str - colorRGB: Union[List[int], Tuple[int, int, int]] - - @pydantic_compat.validator('colorRGB') - def validate_color(cls, colorRGB): - #Does the dtype matter? Can it be a float? - if not isinstance(colorRGB, (tuple, list)): - raise ValueError( - f"Received color that is not a list or tuple. Found : {colorRGB}" - ) - elif len(colorRGB) != 3: - raise ValueError( - f"Must provide RGB values for segmentation colors. Found : {colorRGB}" - ) - elif not all([0 <= color <= 255 for color in colorRGB]): - raise ValueError( - f"All rgb colors must be between 0 and 255. Found : {colorRGB}") - return colorRGB - - -class NDMask(NDBaseTool): - ontology_type: Literal["superpixel"] = "superpixel" - mask: Union[URIMaskFeatures, PNGMaskFeatures, - RLEMaskFeatures] = pydantic_compat.Field(determinant=True) - - -#A union with custom construction logic to improve error messages -class NDTool( - SpecialUnion, - Type[Union[ # type: ignore - NDMask, - NDTextEntity, - NDPoint, - NDRectangle, - NDPolyline, - NDPolygon, - ]]): - ... 
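-
-
-# A usage sketch, not part of the SDK: it assumes NDTool inherits the same
-# `build` classmethod pattern that NDAnnotation defines below, and the ids
-# and bbox values are hypothetical placeholders in the NDJSON rectangle shape.
-#
-#   from uuid import uuid4
-#   bbox_annotation = {
-#       "uuid": str(uuid4()),
-#       "dataRow": {"id": "<data_row_id>"},
-#       "name": "bounding_box",
-#       "bbox": {"top": 48, "left": 58, "height": 213, "width": 215},
-#   }
-#   tool = NDTool.build(bbox_annotation)  # dispatches to NDRectangle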
-
-
-class NDAnnotation(
-        SpecialUnion,
-        Type[Union[  # type: ignore
-            NDTool, NDClassification]]):
-
-    @classmethod
-    def build(cls: Any, data) -> "NDBase":
-        if not isinstance(data, dict):
-            raise ValueError('value must be a dict')
-        errors = []
-        for cl in cls.get_union_types():
-            try:
-                return cl(**data)
-            except KeyError as e:
-                errors.append(f"{cl.__name__}: {e}")
-
-        raise ValueError('Unable to construct any annotation.\n{}'.format(
-            "\n".join(errors)))
-
-    @classmethod
-    def schema(cls):
-        data = {'definitions': {}}
-        for type_ in cls.get_union_types():
-            schema_ = type_.schema()
-            data['definitions'].update(schema_.pop('definitions'))
-            data[type_.__name__] = schema_
-        return data
-
-----
-labelbox/schema/ontology.py
-# type: ignore
-
-import colorsys
-from dataclasses import dataclass, field
-from enum import Enum
-from typing import Any, Dict, List, Optional, Union, Type
-import warnings
-
-from labelbox.exceptions import InconsistentOntologyException
-from labelbox.orm.db_object import DbObject
-from labelbox.orm.model import Field, Relationship
-from labelbox import pydantic_compat
-import json
-
-FeatureSchemaId: Type[str] = pydantic_compat.constr(min_length=25,
-                                                    max_length=25)
-SchemaId: Type[str] = pydantic_compat.constr(min_length=25, max_length=25)
-
-
-class DeleteFeatureFromOntologyResult:
-    archived: bool
-    deleted: bool
-
-    def __str__(self):
-        return "<%s %s>" % (self.__class__.__name__.split(".")[-1],
-                            json.dumps(self.__dict__))
-
-
-class FeatureSchema(DbObject):
-    name = Field.String("name")
-    color = Field.String("color")
-    normalized = Field.Json("normalized")
-
-
-@dataclass
-class Option:
-    """
-    An option is a possible answer within a Classification object in
-    a Project's ontology.
-
-    To instantiate, only the "value" parameter needs to be passed in.
-
-    Example(s):
-        option = Option(value = "Option Example")
-
-    Attributes:
-        value: (str)
-        schema_id: (str)
-        feature_schema_id: (str)
-        options: (list)
-    """
-    value: Union[str, int]
-    label: Optional[Union[str, int]] = None
-    schema_id: Optional[str] = None
-    feature_schema_id: Optional[FeatureSchemaId] = None
-    options: List["Classification"] = field(default_factory=list)
-
-    def __post_init__(self):
-        if self.label is None:
-            self.label = self.value
-
-    @classmethod
-    def from_dict(cls, dictionary: Dict[str, Any]) -> "Option":
-        return cls(value=dictionary["value"],
-                   label=dictionary["label"],
-                   schema_id=dictionary.get("schemaNodeId", None),
-                   feature_schema_id=dictionary.get("featureSchemaId", None),
-                   options=[
-                       Classification.from_dict(o)
-                       for o in dictionary.get("options", [])
-                   ])
-
-    def asdict(self) -> Dict[str, Any]:
-        return {
-            "schemaNodeId": self.schema_id,
-            "featureSchemaId": self.feature_schema_id,
-            "label": self.label,
-            "value": self.value,
-            "options": [o.asdict(is_subclass=True) for o in self.options]
-        }
-
-    def add_option(self, option: 'Classification') -> None:
-        if option.name in (o.name for o in self.options):
-            raise InconsistentOntologyException(
-                f"Duplicate nested classification '{option.name}' "
-                f"for option '{self.label}'")
-        self.options.append(option)
-
-
-@dataclass
-class Classification:
-    """
-
-    Deprecation Notice: Dropdown classification is deprecated and will be
-    removed in a future release. Dropdown will also
-    no longer be able to be created in the Editor on 3/31/2022.
-
-    A classification to be added to a Project's ontology. The
-    classification is dependent on the Classification Type.
- - To instantiate, the "class_type" and "name" parameters must - be passed in. - - The "options" parameter holds a list of Option objects. This is not - necessary for some Classification types, such as TEXT. To see which - types require options, look at the "_REQUIRES_OPTIONS" class variable. - - Example(s): - classification = Classification( - class_type = Classification.Type.TEXT, - name = "Classification Example") - - classification_two = Classification( - class_type = Classification.Type.RADIO, - name = "Second Example") - classification_two.add_option(Option( - value = "Option Example")) - - Attributes: - class_type: (Classification.Type) - name: (str) - instructions: (str) - required: (bool) - options: (list) - schema_id: (str) - feature_schema_id: (str) - """ - - class Type(Enum): - TEXT = "text" - CHECKLIST = "checklist" - RADIO = "radio" - DROPDOWN = "dropdown" - - class Scope(Enum): - GLOBAL = "global" - INDEX = "index" - - _REQUIRES_OPTIONS = {Type.CHECKLIST, Type.RADIO, Type.DROPDOWN} - - class_type: Type - name: Optional[str] = None - instructions: Optional[str] = None - required: bool = False - options: List[Option] = field(default_factory=list) - schema_id: Optional[str] = None - feature_schema_id: Optional[str] = None - scope: Scope = None - - def __post_init__(self): - if self.class_type == Classification.Type.DROPDOWN: - warnings.warn( - "Dropdown classification is deprecated and will be " - "removed in a future release. Dropdown will also " - "no longer be able to be created in the Editor on 3/31/2022.") - - if self.name is None: - msg = ( - "When creating the Classification feature, please use “name” " - "for the classification schema name, which will be used when " - "creating annotation payload for Model-Assisted Labeling " - "Import and Label Import. 
“instructions” is no longer " - "supported to specify classification schema name.") - if self.instructions is not None: - self.name = self.instructions - warnings.warn(msg) - else: - raise ValueError(msg) - else: - if self.instructions is None: - self.instructions = self.name - - @classmethod - def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: - return cls(class_type=cls.Type(dictionary["type"]), - name=dictionary["name"], - instructions=dictionary["instructions"], - required=dictionary.get("required", False), - options=[Option.from_dict(o) for o in dictionary["options"]], - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None), - scope=cls.Scope(dictionary.get("scope", cls.Scope.GLOBAL))) - - def asdict(self, is_subclass: bool = False) -> Dict[str, Any]: - if self.class_type in self._REQUIRES_OPTIONS \ - and len(self.options) < 1: - raise InconsistentOntologyException( - f"Classification '{self.name}' requires options.") - classification = { - "type": self.class_type.value, - "instructions": self.instructions, - "name": self.name, - "required": self.required, - "options": [o.asdict() for o in self.options], - "schemaNodeId": self.schema_id, - "featureSchemaId": self.feature_schema_id - } - if is_subclass: - return classification - classification[ - "scope"] = self.scope.value if self.scope is not None else self.Scope.GLOBAL.value - return classification - - def add_option(self, option: Option) -> None: - if option.value in (o.value for o in self.options): - raise InconsistentOntologyException( - f"Duplicate option '{option.value}' " - f"for classification '{self.name}'.") - self.options.append(option) - - -@dataclass -class Tool: - """ - A tool to be added to a Project's ontology. The tool is - dependent on the Tool Type. - - To instantiate, the "tool" and "name" parameters must - be passed in. - - The "classifications" parameter holds a list of Classification objects. - This can be used to add nested classifications to a tool. 
- - Example(s): - tool = Tool( - tool = Tool.Type.LINE, - name = "Tool example") - classification = Classification( - class_type = Classification.Type.TEXT, - instructions = "Classification Example") - tool.add_classification(classification) - - Attributes: - tool: (Tool.Type) - name: (str) - required: (bool) - color: (str) - classifications: (list) - schema_id: (str) - feature_schema_id: (str) - """ - - class Type(Enum): - POLYGON = "polygon" - SEGMENTATION = "superpixel" - RASTER_SEGMENTATION = "raster-segmentation" - POINT = "point" - BBOX = "rectangle" - LINE = "line" - NER = "named-entity" - RELATIONSHIP = "edge" - - tool: Type - name: str - required: bool = False - color: Optional[str] = None - classifications: List[Classification] = field(default_factory=list) - schema_id: Optional[str] = None - feature_schema_id: Optional[str] = None - - @classmethod - def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: - return cls(name=dictionary['name'], - schema_id=dictionary.get("schemaNodeId", None), - feature_schema_id=dictionary.get("featureSchemaId", None), - required=dictionary.get("required", False), - tool=cls.Type(dictionary["tool"]), - classifications=[ - Classification.from_dict(c) - for c in dictionary["classifications"] - ], - color=dictionary["color"]) - - def asdict(self) -> Dict[str, Any]: - return { - "tool": self.tool.value, - "name": self.name, - "required": self.required, - "color": self.color, - "classifications": [ - c.asdict(is_subclass=True) for c in self.classifications - ], - "schemaNodeId": self.schema_id, - "featureSchemaId": self.feature_schema_id - } - - def add_classification(self, classification: Classification) -> None: - if classification.name in (c.name for c in self.classifications): - raise InconsistentOntologyException( - f"Duplicate nested classification '{classification.name}' " - f"for tool '{self.name}'") - self.classifications.append(classification) - - -class Ontology(DbObject): - """An ontology specifies which tools and classifications are available - to a project. This is read only for now. 
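-
-    Example (a usage sketch; assumes a valid ontology id and client handle):
-        >>> ontology = client.get_ontology("<ontology_id>")
-        >>> tools = ontology.tools()
-        >>> classifications = ontology.classifications()
-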
- Attributes: - name (str) - description (str) - updated_at (datetime) - created_at (datetime) - normalized (json) - object_schema_count (int) - classification_schema_count (int) - projects (Relationship): `ToMany` relationship to Project - created_by (Relationship): `ToOne` relationship to User - """ - - name = Field.String("name") - description = Field.String("description") - updated_at = Field.DateTime("updated_at") - created_at = Field.DateTime("created_at") - normalized = Field.Json("normalized") - object_schema_count = Field.Int("object_schema_count") - classification_schema_count = Field.Int("classification_schema_count") - - projects = Relationship.ToMany("Project", True) - created_by = Relationship.ToOne("User", False, "created_by") - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self._tools: Optional[List[Tool]] = None - self._classifications: Optional[List[Classification]] = None - - def tools(self) -> List[Tool]: - """Get list of tools (AKA objects) in an Ontology.""" - if self._tools is None: - self._tools = [ - Tool.from_dict(tool) for tool in self.normalized['tools'] - ] - return self._tools - - def classifications(self) -> List[Classification]: - """Get list of classifications in an Ontology.""" - if self._classifications is None: - self._classifications = [ - Classification.from_dict(classification) - for classification in self.normalized['classifications'] - ] - return self._classifications - - -@dataclass -class OntologyBuilder: - """ - A class to help create an ontology for a Project. This should be used - for making Project ontologies from scratch. OntologyBuilder can also - pull from an already existing Project's ontology. - - There are no required instantiation arguments. - - To create an ontology, use the asdict() method after fully building your - ontology within this class, and inserting it into project.setup() as the - "labeling_frontend_options" parameter. - - Example: - builder = OntologyBuilder() - ... - frontend = list(client.get_labeling_frontends())[0] - project.setup(frontend, builder.asdict()) - - attributes: - tools: (list) - classifications: (list) - - - """ - tools: List[Tool] = field(default_factory=list) - classifications: List[Classification] = field(default_factory=list) - - @classmethod - def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: - return cls(tools=[Tool.from_dict(t) for t in dictionary["tools"]], - classifications=[ - Classification.from_dict(c) - for c in dictionary["classifications"] - ]) - - def asdict(self) -> Dict[str, Any]: - self._update_colors() - return { - "tools": [t.asdict() for t in self.tools], - "classifications": [c.asdict() for c in self.classifications] - } - - def _update_colors(self): - num_tools = len(self.tools) - - for index in range(num_tools): - hsv_color = (index * 1 / num_tools, 1, 1) - rgb_color = tuple( - int(255 * x) for x in colorsys.hsv_to_rgb(*hsv_color)) - if self.tools[index].color is None: - self.tools[index].color = '#%02x%02x%02x' % rgb_color - - @classmethod - def from_project(cls, project: "project.Project") -> "OntologyBuilder": - ontology = project.ontology().normalized - return cls.from_dict(ontology) - - @classmethod - def from_ontology(cls, ontology: Ontology) -> "OntologyBuilder": - return cls.from_dict(ontology.normalized) - - def add_tool(self, tool: Tool) -> None: - if tool.name in (t.name for t in self.tools): - raise InconsistentOntologyException( - f"Duplicate tool name '{tool.name}'. 
") - self.tools.append(tool) - - def add_classification(self, classification: Classification) -> None: - if classification.name in (c.name for c in self.classifications): - raise InconsistentOntologyException( - f"Duplicate classification name '{classification.name}'. ") - self.classifications.append(classification) - ----- -labelbox/schema/queue_mode.py -from enum import Enum - - -class QueueMode(str, Enum): - Batch = "BATCH" - Dataset = "DATA_SET" - - @classmethod - def _missing_(cls, value): - # Parses the deprecated "CATALOG" value back to QueueMode.Batch. - if value == "CATALOG": - return QueueMode.Batch - ----- -labelbox/schema/create_batches_task.py -import json -from typing import TYPE_CHECKING, Callable, List, Optional, Dict, Any - -from labelbox.orm.model import Entity - -if TYPE_CHECKING: - from labelbox import User - - def lru_cache() -> Callable[..., Callable[..., Dict[str, Any]]]: - pass -else: - from functools import lru_cache - - -class CreateBatchesTask: - - def __init__(self, client, project_id: str, batch_ids: List[str], - task_ids: List[str]): - self.client = client - self.project_id = project_id - self.batches = batch_ids - self.tasks = [ - Entity.Task.get_task(self.client, task_id) for task_id in task_ids - ] - - def wait_till_done(self, timeout_seconds: int = 300) -> None: - """ - Waits for the task to complete. - - Args: - timeout_seconds: the number of seconds to wait before timing out - - Returns: None - """ - - for task in self.tasks: - task.wait_till_done(timeout_seconds) - - def errors(self) -> Optional[Dict[str, Any]]: - """ - Returns the errors from the task, if any. - - Returns: a dictionary of errors, keyed by task id - """ - - errors = {} - for task in self.tasks: - if task.status == "FAILED": - errors[task.uid] = json.loads(task.result_url) - - if len(errors) == 0: - return None - - return errors - - @lru_cache() - def result(self): - """ - Returns the batches created by the task. - - Returns: the list of batches created by the task - """ - - return [ - self.client.get_batch(self.project_id, batch_id) - for batch_id in self.batches - ] - ----- -labelbox/schema/resource_tag.py -from labelbox.orm.db_object import DbObject, Updateable -from labelbox.orm.model import Field, Relationship - - -class ResourceTag(DbObject, Updateable): - """ Resource tag to label and identify your labelbox resources easier. - - Attributes: - text (str) - color (str) - - project_resource_tag (Relationship): `ToMany` relationship to ProjectResourceTag - """ - - text = Field.String("text") - color = Field.String("color") - ----- -labelbox/schema/identifiable.py -from abc import ABC -from typing import Union - -from labelbox.schema.id_type import IdType - - -class Identifiable(ABC): - """ - Base class for any object representing a unique identifier. - """ - - def __init__(self, key: str, id_type: IdType): - self._key = key - self._id_type = id_type - - @property - def key(self): - return self._key - - @property - def id_type(self): - return self._id_type - - def __eq__(self, other): - return other.key == self.key and other.id_type == self.id_type - - def __hash__(self): - return hash((self.key, self.id_type)) - - def __str__(self): - return f"{self.id_type}:{self.key}" - - -class UniqueId(Identifiable): - """ - Represents a unique, internally generated id. - """ - - def __init__(self, key: str): - super().__init__(key, IdType.DataRowId) - - -class GlobalKey(Identifiable): - """ - Represents a user generated id. 
- """ - - def __init__(self, key: str): - super().__init__(key, IdType.GlobalKey) - - -DataRowIdentifier = Union[UniqueId, GlobalKey] - ----- -labelbox/schema/iam_integration.py -from dataclasses import dataclass - -from labelbox.utils import snake_case -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field - - -@dataclass -class AwsIamIntegrationSettings: - role_arn: str - - -@dataclass -class GcpIamIntegrationSettings: - service_account_email_id: str - read_bucket: str - - -class IAMIntegration(DbObject): - """ Represents an IAM integration for delegated access - - Attributes: - name (str) - updated_at (datetime) - created_at (datetime) - provider (str) - valid (bool) - last_valid_at (datetime) - is_org_default (boolean) - - """ - - def __init__(self, client, data): - settings = data.pop('settings', None) - if settings is not None: - type_name = settings.pop('__typename') - settings = {snake_case(k): v for k, v in settings.items()} - if type_name == "GcpIamIntegrationSettings": - self.settings = GcpIamIntegrationSettings(**settings) - elif type_name == "AwsIamIntegrationSettings": - self.settings = AwsIamIntegrationSettings(**settings) - else: - self.settings = None - else: - self.settings = None - super().__init__(client, data) - - _DEFAULT = "DEFAULT" - - name = Field.String("name") - created_at = Field.DateTime("created_at") - updated_at = Field.DateTime("updated_at") - provider = Field.String("provider") - valid = Field.Boolean("valid") - last_valid_at = Field.DateTime("last_valid_at") - is_org_default = Field.Boolean("is_org_default") - ----- -labelbox/schema/media_type.py -from enum import Enum - - -class MediaType(Enum): - Audio = "AUDIO" - Conversational = "CONVERSATIONAL" - Dicom = "DICOM" - Document = "PDF" - Geospatial_Tile = "TMS_GEO" - Html = "HTML" - Image = "IMAGE" - Json = "JSON" - LLMPromptCreation = "LLM_PROMPT_CREATION" - LLMPromptResponseCreation = "LLM_PROMPT_RESPONSE_CREATION" - Pdf = "PDF" - Simple_Tile = "TMS_SIMPLE" - Text = "TEXT" - Tms_Geo = "TMS_GEO" - Tms_Simple = "TMS_SIMPLE" - Unknown = "UNKNOWN" - Unsupported = "UNSUPPORTED" - Video = "VIDEO" - - @classmethod - def _missing_(cls, name): - """Handle missing null data types for projects - created without setting allowedMediaType - Handle upper case names for compatibility with - the GraphQL""" - - if name is None: - return cls.Unknown - - for member in cls.__members__: - if member.name == name.upper(): - return member - - @classmethod - def is_supported(cls, value): - return isinstance(value, - cls) and value not in [cls.Unknown, cls.Unsupported] - - @classmethod - def get_supported_members(cls): - return [ - item for item in cls.__members__ - if item not in ["Unknown", "Unsupported"] - ] - - -def get_media_type_validation_error(media_type): - return TypeError(f"{media_type} is not a valid media type. Use" - f" any of {MediaType.get_supported_members()}" - " from MediaType. Example: MediaType.Image.") - ----- -labelbox/schema/send_to_annotate_params.py -import sys - -from typing import Optional, Dict - -from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy - -if sys.version_info >= (3, 8): - from typing import TypedDict -else: - from typing_extensions import TypedDict - - -class SendToAnnotateFromCatalogParams(TypedDict): - """ - Extra parameters for sending data rows to a project through catalog. At least one of source_model_run_id or - source_project_id must be provided. 
- - :param source_model_run_id: Optional[str] - The model run to use for predictions. Defaults to None. - :param predictions_ontology_mapping: Optional[Dict[str, str]] - A mapping of feature schema ids to feature schema - ids. Defaults to an empty dictionary. - :param source_project_id: Optional[str] - The project to use for predictions. Defaults to None. - :param annotations_ontology_mapping: Optional[Dict[str, str]] - A mapping of feature schema ids to feature schema - ids. Defaults to an empty dictionary. - :param exclude_data_rows_in_project: Optional[bool] - Exclude data rows that are already in the project. Defaults - to False. - :param override_existing_annotations_rule: Optional[ConflictResolutionStrategy] - The strategy defining how to - handle conflicts in classifications between the data rows that already exist in the project and incoming - predictions from the source model run or annotations from the source project. Defaults to - ConflictResolutionStrategy.KEEP_EXISTING. - :param batch_priority: Optional[int] - The priority of the batch. Defaults to 5. - """ - - source_model_run_id: Optional[str] - predictions_ontology_mapping: Optional[Dict[str, str]] - source_project_id: Optional[str] - annotations_ontology_mapping: Optional[Dict[str, str]] - exclude_data_rows_in_project: Optional[bool] - override_existing_annotations_rule: Optional[ConflictResolutionStrategy] - batch_priority: Optional[int] - - -class SendToAnnotateFromModelParams(TypedDict): - """ - Extra parameters for sending data rows to a project through a model run. - - :param predictions_ontology_mapping: Dict[str, str] - A mapping of feature schema ids to feature schema ids. - Defaults to an empty dictionary. - :param exclude_data_rows_in_project: Optional[bool] - Exclude data rows that are already in the project. Defaults - to False. - :param override_existing_annotations_rule: Optional[ConflictResolutionStrategy] - The strategy defining how to - handle conflicts in classifications between the data rows that already exist in the project and incoming - predictions from the source model run. Defaults to ConflictResolutionStrategy.KEEP_EXISTING. - :param batch_priority: Optional[int] - The priority of the batch. Defaults to 5. 
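-
-    A construction sketch (the mapping ids are hypothetical placeholders):
-
-    >>> params: SendToAnnotateFromModelParams = {
-    ...     "predictions_ontology_mapping": {"<source_fs_id>": "<target_fs_id>"},
-    ...     "exclude_data_rows_in_project": False,
-    ...     "override_existing_annotations_rule": ConflictResolutionStrategy.KEEP_EXISTING,
-    ...     "batch_priority": 5,
-    ... }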
- """ - - predictions_ontology_mapping: Dict[str, str] - exclude_data_rows_in_project: Optional[bool] - override_existing_annotations_rule: Optional[ConflictResolutionStrategy] - batch_priority: Optional[int] - - -def build_annotations_input(project_ontology_mapping: Optional[Dict[str, str]], - source_project_id: str): - return { - "projectId": - source_project_id, - "featureSchemaIdsMapping": - project_ontology_mapping if project_ontology_mapping else {}, - } - - -def build_destination_task_queue_input(task_queue_id: str): - destination_task_queue = { - "type": "id", - "value": task_queue_id - } if task_queue_id else { - "type": "done" - } - return destination_task_queue - - -def build_predictions_input(model_run_ontology_mapping: Optional[Dict[str, - str]], - source_model_run_id: str): - return { - "featureSchemaIdsMapping": - model_run_ontology_mapping if model_run_ontology_mapping else {}, - "modelRunId": - source_model_run_id, - "minConfidence": - 0, - "maxConfidence": - 1 - } - ----- -labelbox/schema/quality_mode.py -from enum import Enum - - -class QualityMode(str, Enum): - Benchmark = "BENCHMARK" - Consensus = "CONSENSUS" - - -BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS = 1 -BENCHMARK_AUTO_AUDIT_PERCENTAGE = 1 -CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS = 3 -CONSENSUS_AUTO_AUDIT_PERCENTAGE = 0 - ----- -labelbox/schema/serialization.py -from typing import cast, Any, Dict, Generator, List, TYPE_CHECKING, Union - -if TYPE_CHECKING: - from labelbox.types import Label - - -def serialize_labels( - objects: Union[List[Dict[str, Any]], - List["Label"]]) -> List[Dict[str, Any]]: - """ - Checks if objects are of type Labels and serializes labels for annotation import. Serialization depends the labelbox[data] package, therefore NDJsonConverter is only loaded if using `Label` objects instead of `dict` objects. - """ - if len(objects) == 0: - return [] - - is_label_type = not isinstance(objects[0], Dict) - if is_label_type: - # If a Label object exists, labelbox[data] is already installed, so no error checking is needed. 
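-        # The import below is deliberately deferred so that labelbox[data]
-        # remains an optional dependency; plain-dict payloads never reach it.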
- from labelbox.data.serialization import NDJsonConverter - labels = cast(List["Label"], objects) - return list(NDJsonConverter.serialize(labels)) - - return cast(List[Dict[str, Any]], objects) - ----- -labelbox/schema/project.py -import json -import logging -from string import Template -import time -import warnings -from collections import namedtuple -from datetime import datetime, timezone -from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, TypeVar, Union, overload -from urllib.parse import urlparse - -import requests - -from labelbox import parser -from labelbox import utils -from labelbox.exceptions import ( - InvalidQueryError, - LabelboxError, - ProcessingWaitTimeout, - ResourceConflict, -) -from labelbox.orm import query -from labelbox.orm.db_object import DbObject, Deletable, Updateable, experimental -from labelbox.orm.model import Entity, Field, Relationship -from labelbox.pagination import PaginatedCollection -from labelbox.schema.consensus_settings import ConsensusSettings -from labelbox.schema.create_batches_task import CreateBatchesTask -from labelbox.schema.data_row import DataRow -from labelbox.schema.export_filters import ProjectExportFilters, validate_datetime, build_filters -from labelbox.schema.export_params import ProjectExportParams -from labelbox.schema.export_task import ExportTask -from labelbox.schema.id_type import IdType -from labelbox.schema.identifiable import DataRowIdentifier, GlobalKey, UniqueId -from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds -from labelbox.schema.media_type import MediaType -from labelbox.schema.queue_mode import QueueMode -from labelbox.schema.resource_tag import ResourceTag -from labelbox.schema.task import Task -from labelbox.schema.task_queue import TaskQueue - -if TYPE_CHECKING: - from labelbox import BulkImportRequest - -try: - from labelbox.data.serialization import LBV1Converter -except ImportError: - pass - -DataRowPriority = int -LabelingParameterOverrideInput = Tuple[Union[DataRow, DataRowIdentifier], - DataRowPriority] - -logger = logging.getLogger(__name__) - - -def validate_labeling_parameter_overrides( - data: List[LabelingParameterOverrideInput]) -> None: - for idx, row in enumerate(data): - if len(row) < 2: - raise TypeError( - f"Data must be a list of tuples each containing two elements: a DataRow or a DataRowIdentifier and priority (int). Found {len(row)} items. Index: {idx}" - ) - data_row_identifier = row[0] - priority = row[1] - valid_types = (Entity.DataRow, UniqueId, GlobalKey) - if not isinstance(data_row_identifier, valid_types): - raise TypeError( - f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found {type(data_row_identifier)} for data_row_identifier {data_row_identifier}" - ) - - if not isinstance(priority, int): - if isinstance(data_row_identifier, Entity.DataRow): - id = data_row_identifier.uid - else: - id = data_row_identifier - raise TypeError( - f"Priority must be an int. Found {type(priority)} for data_row_identifier {id}" - ) - - -class Project(DbObject, Updateable, Deletable): - """ A Project is a container that includes a labeling frontend, an ontology, - datasets and labels. 
- - Attributes: - name (str) - description (str) - updated_at (datetime) - created_at (datetime) - setup_complete (datetime) - last_activity_time (datetime) - queue_mode (string) - auto_audit_number_of_labels (int) - auto_audit_percentage (float) - - created_by (Relationship): `ToOne` relationship to User - organization (Relationship): `ToOne` relationship to Organization - labeling_frontend (Relationship): `ToOne` relationship to LabelingFrontend - labeling_frontend_options (Relationship): `ToMany` relationship to LabelingFrontendOptions - labeling_parameter_overrides (Relationship): `ToMany` relationship to LabelingParameterOverride - webhooks (Relationship): `ToMany` relationship to Webhook - benchmarks (Relationship): `ToMany` relationship to Benchmark - ontology (Relationship): `ToOne` relationship to Ontology - task_queues (Relationship): `ToMany` relationship to TaskQueue - """ - - name = Field.String("name") - description = Field.String("description") - updated_at = Field.DateTime("updated_at") - created_at = Field.DateTime("created_at") - setup_complete = Field.DateTime("setup_complete") - last_activity_time = Field.DateTime("last_activity_time") - queue_mode = Field.Enum(QueueMode, "queue_mode") - auto_audit_number_of_labels = Field.Int("auto_audit_number_of_labels") - auto_audit_percentage = Field.Float("auto_audit_percentage") - # Bind data_type and allowedMediaTYpe using the GraphQL type MediaType - media_type = Field.Enum(MediaType, "media_type", "allowedMediaType") - - # Relationships - created_by = Relationship.ToOne("User", False, "created_by") - organization = Relationship.ToOne("Organization", False) - labeling_frontend = Relationship.ToOne("LabelingFrontend") - labeling_frontend_options = Relationship.ToMany( - "LabelingFrontendOptions", False, "labeling_frontend_options") - labeling_parameter_overrides = Relationship.ToMany( - "LabelingParameterOverride", False, "labeling_parameter_overrides") - webhooks = Relationship.ToMany("Webhook", False) - benchmarks = Relationship.ToMany("Benchmark", False) - ontology = Relationship.ToOne("Ontology", True) - - # - _wait_processing_max_seconds = 3600 - - def update(self, **kwargs): - """ Updates this project with the specified attributes - - Args: - kwargs: a dictionary containing attributes to be upserted - - Note that the queue_mode cannot be changed after a project has been created. - - Additionally, the quality setting cannot be changed after a project has been created. The quality mode - for a project is inferred through the following attributes: - - Benchmark: - auto_audit_number_of_labels = 1 and auto_audit_percentage = 1.0 - - Consensus: - auto_audit_number_of_labels > 1 or auto_audit_percentage <= 1.0 - - Attempting to switch between benchmark and consensus modes is an invalid operation and will result - in an error. - """ - - media_type = kwargs.get("media_type") - if media_type: - if MediaType.is_supported(media_type): - kwargs["media_type"] = media_type.value - else: - raise TypeError(f"{media_type} is not a valid media type. Use" - f" any of {MediaType.get_supported_members()}" - " from MediaType. Example: MediaType.Image.") - - return super().update(**kwargs) - - def members(self) -> PaginatedCollection: - """ Fetch all current members for this project - - Returns: - A `PaginatedCollection` of `ProjectMember`s - - """ - id_param = "projectId" - query_str = """query ProjectMemberOverviewPyApi($%s: ID!) 
{ - project(where: {id : $%s}) { id members(skip: %%d first: %%d){ id user { %s } role { id name } accessFrom } - } - }""" % (id_param, id_param, query.results_query_part(Entity.User)) - return PaginatedCollection(self.client, query_str, - {id_param: str(self.uid)}, - ["project", "members"], ProjectMember) - - def update_project_resource_tags( - self, resource_tag_ids: List[str]) -> List[ResourceTag]: - """ Creates project resource tags - - Args: - resource_tag_ids - Returns: - a list of ResourceTag ids that was created. - """ - project_id_param = "projectId" - tag_ids_param = "resourceTagIds" - - query_str = """mutation UpdateProjectResourceTagsPyApi($%s:ID!,$%s:[String!]) { - project(where:{id:$%s}){updateProjectResourceTags(input:{%s:$%s}){%s}}}""" % ( - project_id_param, tag_ids_param, project_id_param, tag_ids_param, - tag_ids_param, query.results_query_part(ResourceTag)) - - res = self.client.execute(query_str, { - project_id_param: self.uid, - tag_ids_param: resource_tag_ids - }) - - return [ - ResourceTag(self.client, tag) - for tag in res["project"]["updateProjectResourceTags"] - ] - - def get_resource_tags(self) -> List[ResourceTag]: - """ - Returns tags for a project - """ - query_str = """query GetProjectResourceTagsPyApi($projectId: ID!) { - project(where: {id: $projectId}) { - name - resourceTags {%s} - } - }""" % (query.results_query_part(ResourceTag)) - - results = self.client.execute( - query_str, {"projectId": self.uid})['project']['resourceTags'] - - return [ResourceTag(self.client, tag) for tag in results] - - def labels(self, datasets=None, order_by=None) -> PaginatedCollection: - """ Custom relationship expansion method to support limited filtering. - - Args: - datasets (iterable of Dataset): Optional collection of Datasets - whose Labels are sought. If not provided, all Labels in - this Project are returned. - order_by (None or (Field, Field.Order)): Ordering clause. - """ - Label = Entity.Label - - if datasets is not None: - where = " where:{dataRow: {dataset: {id_in: [%s]}}}" % ", ".join( - '"%s"' % dataset.uid for dataset in datasets) - else: - where = "" - - if order_by is not None: - query.check_order_by_clause(Label, order_by) - order_by_str = "orderBy: %s_%s" % (order_by[0].graphql_name, - order_by[1].name.upper()) - else: - order_by_str = "" - - id_param = "projectId" - query_str = """query GetProjectLabelsPyApi($%s: ID!) - {project (where: {id: $%s}) - {labels (skip: %%d first: %%d %s %s) {%s}}}""" % ( - id_param, id_param, where, order_by_str, - query.results_query_part(Label)) - - return PaginatedCollection(self.client, query_str, {id_param: self.uid}, - ["project", "labels"], Label) - - def export_queued_data_rows( - self, - timeout_seconds=120, - include_metadata: bool = False) -> List[Dict[str, str]]: - """ Returns all data rows that are currently enqueued for this project. - - Args: - timeout_seconds (float): Max waiting time, in seconds. - include_metadata (bool): True to return related DataRow metadata - Returns: - Data row fields for all data rows in the queue as json - Raises: - LabelboxError: if the export fails or is unable to download within the specified time. - """ - warnings.warn( - "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. 
To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) - id_param = "projectId" - metadata_param = "includeMetadataInput" - query_str = """mutation GetQueuedDataRowsExportUrlPyApi($%s: ID!, $%s: Boolean!) - {exportQueuedDataRows(data:{projectId: $%s , includeMetadataInput: $%s}) {downloadUrl createdAt status} } - """ % (id_param, metadata_param, id_param, metadata_param) - sleep_time = 2 - start_time = time.time() - while True: - res = self.client.execute(query_str, { - id_param: self.uid, - metadata_param: include_metadata - }) - res = res["exportQueuedDataRows"] - if res["status"] == "COMPLETE": - download_url = res["downloadUrl"] - response = requests.get(download_url) - response.raise_for_status() - return parser.loads(response.text) - elif res["status"] == "FAILED": - raise LabelboxError("Data row export failed.") - - current_time = time.time() - if current_time - start_time > timeout_seconds: - raise LabelboxError( - f"Unable to export data rows within {timeout_seconds} seconds." - ) - - logger.debug( - "Project '%s' queued data row export, waiting for server...", - self.uid) - time.sleep(sleep_time) - - def label_generator(self, timeout_seconds=600, **kwargs): - """ - Download text and image annotations, or video annotations. - - For a mixture of text/image and video, use project.export_labels() - - Returns: - LabelGenerator for accessing labels - """ - _check_converter_import() - json_data = self.export_labels(download=True, - timeout_seconds=timeout_seconds, - **kwargs) - - # assert that the instance this would fail is only if timeout runs out - assert isinstance( - json_data, - List), "Unable to successfully get labels. Please try again" - - if json_data is None: - raise TimeoutError( - f"Unable to download labels in {timeout_seconds} seconds." - "Please try again or contact support if the issue persists.") - - is_video = [ - "frames" in row["Label"] - for row in json_data - if row["Label"] and not row["Skipped"] - ] - - if len(is_video) and not all(is_video) and any(is_video): - raise ValueError( - "Found mixed data types of video and text/image. " - "Use project.export_labels() to export projects with mixed data types. " - ) - if len(is_video) and all(is_video): - # Filter skipped labels to avoid inference errors - json_data = [ - label for label in self.export_labels(download=True) - if not label["Skipped"] - ] - - return LBV1Converter.deserialize_video(json_data, self.client) - - return LBV1Converter.deserialize(json_data) - - def export_labels(self, - download=False, - timeout_seconds=1800, - **kwargs) -> Optional[Union[str, List[Dict[Any, Any]]]]: - """ Calls the server-side Label exporting that generates a JSON - payload, and returns the URL to that payload. - - Will only generate a new URL at a max frequency of 30 min. - - Args: - download (bool): Returns the url if False - timeout_seconds (float): Max waiting time, in seconds. - start (str): Earliest date for labels, formatted "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss" - end (str): Latest date for labels, formatted "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss" - last_activity_start (str): Will include all labels that have had any updates to - data rows, issues, comments, metadata, or reviews since this timestamp. - formatted "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss" - last_activity_end (str): Will include all labels that do not have any updates to - data rows, issues, comments, metadata, or reviews after this timestamp. 
- formatted "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss" - - Returns: - URL of the data file with this Project's labels. If the server didn't - generate during the `timeout_seconds` period, None is returned. - """ - warnings.warn( - "You are currently utilizing exports v1 for this action, which will be deprecated after April 30th, 2024. We recommend transitioning to exports v2. To view export v2 details, visit our docs: https://docs.labelbox.com/reference/label-export", - DeprecationWarning) - - def _string_from_dict(dictionary: dict, value_with_quotes=False) -> str: - """Returns a concatenated string of the dictionary's keys and values - - The string will be formatted as {key}: 'value' for each key. Value will be inclusive of - quotations while key will not. This can be toggled with `value_with_quotes`""" - - quote = "\"" if value_with_quotes else "" - return ",".join([ - f"""{c}: {quote}{dictionary.get(c)}{quote}""" - for c in dictionary - if dictionary.get(c) - ]) - - sleep_time = 2 - id_param = "projectId" - filter_param = "" - filter_param_dict = {} - - if "start" in kwargs or "end" in kwargs: - created_at_dict = { - "start": kwargs.get("start", ""), - "end": kwargs.get("end", "") - } - [validate_datetime(date) for date in created_at_dict.values()] - filter_param_dict["labelCreatedAt"] = "{%s}" % _string_from_dict( - created_at_dict, value_with_quotes=True) - - if "last_activity_start" in kwargs or "last_activity_end" in kwargs: - last_activity_start = kwargs.get('last_activity_start') - last_activity_end = kwargs.get('last_activity_end') - - if last_activity_start: - validate_datetime(str(last_activity_start)) - if last_activity_end: - validate_datetime(str(last_activity_end)) - - filter_param_dict["lastActivityAt"] = "{%s}" % _string_from_dict( - { - "start": last_activity_start, - "end": last_activity_end - }, - value_with_quotes=True) - - if filter_param_dict: - filter_param = """, filters: {%s }""" % (_string_from_dict( - filter_param_dict, value_with_quotes=False)) - - query_str = """mutation GetLabelExportUrlPyApi($%s: ID!) - {exportLabels(data:{projectId: $%s%s}) {downloadUrl createdAt shouldPoll} } - """ % (id_param, id_param, filter_param) - - start_time = time.time() - - while True: - res = self.client.execute(query_str, {id_param: self.uid}) - res = res["exportLabels"] - if not res["shouldPoll"] and res["downloadUrl"] is not None: - url = res['downloadUrl'] - if not download: - return url - else: - response = requests.get(url) - response.raise_for_status() - return response.json() - - current_time = time.time() - if current_time - start_time > timeout_seconds: - return None - - logger.debug("Project '%s' label export, waiting for server...", - self.uid) - time.sleep(sleep_time) - - @experimental - def export( - self, - task_name: Optional[str] = None, - filters: Optional[ProjectExportFilters] = None, - params: Optional[ProjectExportParams] = None, - ) -> ExportTask: - """ - Creates a project export task with the given params and returns the task. - - >>> task = project.export( - >>> filters={ - >>> "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - >>> "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - >>> "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] # or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...] - >>> "batch_ids": [BATCH_ID_1, BATCH_ID_2, ...] 
- >>> }, - >>> params={ - >>> "performance_details": False, - >>> "label_details": True - >>> }) - >>> task.wait_till_done() - >>> task.result - """ - task = self._export(task_name, filters, params, streamable=True) - return ExportTask(task) - - def export_v2( - self, - task_name: Optional[str] = None, - filters: Optional[ProjectExportFilters] = None, - params: Optional[ProjectExportParams] = None, - ) -> Task: - """ - Creates a project export task with the given params and returns the task. - - For more information visit: https://docs.labelbox.com/docs/exports-v2#export-from-a-project-python-sdk - - >>> task = project.export_v2( - >>> filters={ - >>> "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - >>> "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"], - >>> "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...] # or global_keys: [DATA_ROW_GLOBAL_KEY_1, DATA_ROW_GLOBAL_KEY_2, ...] - >>> "batch_ids": [BATCH_ID_1, BATCH_ID_2, ...] - >>> }, - >>> params={ - >>> "performance_details": False, - >>> "label_details": True - >>> }) - >>> task.wait_till_done() - >>> task.result - """ - return self._export(task_name, filters, params) - - def _export( - self, - task_name: Optional[str] = None, - filters: Optional[ProjectExportFilters] = None, - params: Optional[ProjectExportParams] = None, - streamable: bool = False, - ) -> Task: - _params = params or ProjectExportParams({ - "attachments": False, - "metadata_fields": False, - "data_row_details": False, - "project_details": False, - "performance_details": False, - "label_details": False, - "media_type_override": None, - "interpolated_frames": False, - }) - - _filters = filters or ProjectExportFilters({ - "last_activity_at": None, - "label_created_at": None, - "data_row_ids": None, - "global_keys": None, - "batch_ids": None, - "workflow_status": None - }) - - mutation_name = "exportDataRowsInProject" - create_task_query_str = ( - f"mutation {mutation_name}PyApi" - f"($input: ExportDataRowsInProjectInput!)" - f"{{{mutation_name}(input: $input){{taskId}}}}") - - media_type_override = _params.get('media_type_override', None) - query_params: Dict[str, Any] = { - "input": { - "taskName": task_name, - "filters": { - "projectId": self.uid, - "searchQuery": { - "scope": None, - "query": [], - } - }, - "params": { - "mediaTypeOverride": - media_type_override.value - if media_type_override is not None else None, - "includeAttachments": - _params.get('attachments', False), - "includeMetadata": - _params.get('metadata_fields', False), - "includeDataRowDetails": - _params.get('data_row_details', False), - "includeProjectDetails": - _params.get('project_details', False), - "includePerformanceDetails": - _params.get('performance_details', False), - "includeLabelDetails": - _params.get('label_details', False), - "includeInterpolatedFrames": - _params.get('interpolated_frames', False), - }, - "streamable": streamable, - } - } - - search_query = build_filters(self.client, _filters) - query_params["input"]["filters"]["searchQuery"]["query"] = search_query - - res = self.client.execute(create_task_query_str, - query_params, - error_log_key="errors") - res = res[mutation_name] - task_id = res["taskId"] - return Task.get_task(self.client, task_id) - - def export_issues(self, status=None) -> str: - """ Calls the server-side Issues exporting that - returns the URL to that payload. - - Args: - status (string): valid values: Open, Resolved - Returns: - URL of the data file with this Project's issues. 
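-
-        Example (a sketch; assumes an existing Project instance):
-            >>> url = project.export_issues(status="Open")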
- """ - id_param = "projectId" - status_param = "status" - query_str = """query GetProjectIssuesExportPyApi($%s: ID!, $%s: IssueStatus) { - project(where: { id: $%s }) { - issueExportUrl(where: { status: $%s }) - } - }""" % (id_param, status_param, id_param, status_param) - - valid_statuses = {None, "Open", "Resolved"} - - if status not in valid_statuses: - raise ValueError("status must be in {}. Found {}".format( - valid_statuses, status)) - - res = self.client.execute(query_str, { - id_param: self.uid, - status_param: status - }) - - res = res['project'] - - logger.debug("Project '%s' issues export, link generated", self.uid) - - return res.get('issueExportUrl') - - def upsert_instructions(self, instructions_file: str) -> None: - """ - * Uploads instructions to the UI. Running more than once will replace the instructions - - Args: - instructions_file (str): Path to a local file. - * Must be a pdf or html file - - Raises: - ValueError: - * project must be setup - * instructions file must have a ".pdf" or ".html" extension - """ - - if self.setup_complete is None: - raise ValueError( - "Cannot attach instructions to a project that has not been set up." - ) - - frontend = self.labeling_frontend() - - if frontend.name != "Editor": - logger.warning( - f"This function has only been tested to work with the Editor front end. Found %s", - frontend.name) - - supported_instruction_formats = (".pdf", ".html") - if not instructions_file.endswith(supported_instruction_formats): - raise ValueError( - f"instructions_file must be a pdf or html file. Found {instructions_file}" - ) - - instructions_url = self.client.upload_file(instructions_file) - - query_str = """mutation setprojectinsructionsPyApi($projectId: ID!, $instructions_url: String!) { - setProjectInstructions( - where: {id: $projectId}, - data: {instructionsUrl: $instructions_url} - ) { - id - ontology { - id - options - } - } - }""" - - self.client.execute(query_str, { - 'projectId': self.uid, - 'instructions_url': instructions_url - }) - - def labeler_performance(self) -> PaginatedCollection: - """ Returns the labeler performances for this Project. - - Returns: - A PaginatedCollection of LabelerPerformance objects. - """ - id_param = "projectId" - query_str = """query LabelerPerformancePyApi($%s: ID!) { - project(where: {id: $%s}) { - labelerPerformance(skip: %%d first: %%d) { - count user {%s} secondsPerLabel totalTimeLabeling consensus - averageBenchmarkAgreement lastActivityTime} - }}""" % (id_param, id_param, query.results_query_part(Entity.User)) - - def create_labeler_performance(client, result): - result["user"] = Entity.User(client, result["user"]) - # python isoformat doesn't accept Z as utc timezone - result["lastActivityTime"] = utils.format_iso_from_string( - result["lastActivityTime"].replace('Z', '+00:00')) - return LabelerPerformance(**{ - utils.snake_case(key): value for key, value in result.items() - }) - - return PaginatedCollection(self.client, query_str, {id_param: self.uid}, - ["project", "labelerPerformance"], - create_labeler_performance) - - def review_metrics(self, net_score) -> int: - """ Returns this Project's review metrics. - - Args: - net_score (None or Review.NetScore): Indicates desired metric. - Returns: - int, aggregation count of reviews for given `net_score`. 
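-
-        Example (a sketch; passing None aggregates across all net scores):
-            >>> count = project.review_metrics(None)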
- """ - if net_score not in (None,) + tuple(Entity.Review.NetScore): - raise InvalidQueryError( - "Review metrics net score must be either None " - "or one of Review.NetScore values") - id_param = "projectId" - net_score_literal = "None" if net_score is None else net_score.name - query_str = """query ProjectReviewMetricsPyApi($%s: ID!){ - project(where: {id:$%s}) - {reviewMetrics {labelAggregate(netScore: %s) {count}}} - }""" % (id_param, id_param, net_score_literal) - res = self.client.execute(query_str, {id_param: self.uid}) - return res["project"]["reviewMetrics"]["labelAggregate"]["count"] - - def setup_editor(self, ontology) -> None: - """ - Sets up the project using the Pictor editor. - - Args: - ontology (Ontology): The ontology to attach to the project - """ - if self.labeling_frontend() is not None: - raise ResourceConflict("Editor is already set up.") - - labeling_frontend = next( - self.client.get_labeling_frontends( - where=Entity.LabelingFrontend.name == "Editor")) - self.labeling_frontend.connect(labeling_frontend) - - LFO = Entity.LabelingFrontendOptions - self.client._create( - LFO, { - LFO.project: - self, - LFO.labeling_frontend: - labeling_frontend, - LFO.customization_options: - json.dumps({ - "tools": [], - "classifications": [] - }) - }) - - query_str = """mutation ConnectOntologyPyApi($projectId: ID!, $ontologyId: ID!){ - project(where: {id: $projectId}) {connectOntology(ontologyId: $ontologyId) {id}}}""" - self.client.execute(query_str, { - 'ontologyId': ontology.uid, - 'projectId': self.uid - }) - timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") - self.update(setup_complete=timestamp) - - def setup(self, labeling_frontend, labeling_frontend_options) -> None: - """ Finalizes the Project setup. - - Args: - labeling_frontend (LabelingFrontend): Which UI to use to label the - data. - labeling_frontend_options (dict or str): Labeling frontend options, - a.k.a. project ontology. If given a `dict` it will be converted - to `str` using `json.dumps`. - """ - - if self.labeling_frontend() is not None: - raise ResourceConflict("Editor is already set up.") - - if not isinstance(labeling_frontend_options, str): - labeling_frontend_options = json.dumps(labeling_frontend_options) - - self.labeling_frontend.connect(labeling_frontend) - - LFO = Entity.LabelingFrontendOptions - self.client._create( - LFO, { - LFO.project: self, - LFO.labeling_frontend: labeling_frontend, - LFO.customization_options: labeling_frontend_options - }) - - timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") - self.update(setup_complete=timestamp) - - def create_batch( - self, - name: str, - data_rows: Optional[List[Union[str, DataRow]]] = None, - priority: int = 5, - consensus_settings: Optional[Dict[str, float]] = None, - global_keys: Optional[List[str]] = None, - ): - """ - Creates a new batch for a project. One of `global_keys` or `data_rows` must be provided, but not both. A - maximum of 100,000 data rows can be added to a batch. - - Args: - name: a name for the batch, must be unique within a project - data_rows: Either a list of `DataRows` or Data Row ids. - global_keys: global keys for data rows to add to the batch. - priority: An optional priority for the Data Rows in the Batch. 1 highest -> 5 lowest - consensus_settings: An optional dictionary with consensus settings: {'number_of_labels': 3, - 'coverage_percentage': 0.1} - - Returns: the created batch - """ - - # @TODO: make this automatic? 
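-        # Flow below: validate the queue mode and inputs, wait for data row
-        # processing, then create the batch synchronously (< 1k rows) or
-        # asynchronously (>= 1k rows).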
- if self.queue_mode != QueueMode.Batch: - raise ValueError("Project must be in batch mode") - - dr_ids = [] - if data_rows is not None: - for dr in data_rows: - if isinstance(dr, Entity.DataRow): - dr_ids.append(dr.uid) - elif isinstance(dr, str): - dr_ids.append(dr) - else: - raise ValueError( - "`data_rows` must be DataRow ids or DataRow objects") - - if data_rows is not None: - row_count = len(dr_ids) - elif global_keys is not None: - row_count = len(global_keys) - else: - row_count = 0 - - if row_count > 100_000: - raise ValueError( - f"Batch exceeds max size, break into smaller batches") - if not row_count: - raise ValueError("You need at least one data row in a batch") - - self._wait_until_data_rows_are_processed( - dr_ids, global_keys, self._wait_processing_max_seconds) - - if consensus_settings: - consensus_settings = ConsensusSettings(**consensus_settings).dict( - by_alias=True) - - if row_count >= 1_000: - return self._create_batch_async(name, dr_ids, global_keys, priority, - consensus_settings) - else: - return self._create_batch_sync(name, dr_ids, global_keys, priority, - consensus_settings) - - def create_batches( - self, - name_prefix: str, - data_rows: Optional[List[Union[str, DataRow]]] = None, - global_keys: Optional[List[str]] = None, - priority: int = 5, - consensus_settings: Optional[Dict[str, float]] = None, - ) -> CreateBatchesTask: - """ - Creates batches for a project from a list of data rows. One of `global_keys` or `data_rows` must be provided, - but not both. When more than 100k data rows are specified and thus multiple batches are needed, the specific - batch that each data row will be placed in is undefined. - - Batches will be created with the specified name prefix and a unique suffix. The suffix will be a 4-digit - number starting at 0000. For example, if the name prefix is "batch" and 3 batches are created, the names - will be "batch0000", "batch0001", and "batch0002". This method will throw an error if a batch with the same - name already exists. - - Args: - name_prefix: a prefix for the batch names, must be unique within a project - data_rows: Either a list of `DataRows` or Data Row ids. - global_keys: global keys for data rows to add to the batch. - priority: An optional priority for the Data Rows in the Batch. 1 highest -> 5 lowest - consensus_settings: An optional dictionary with consensus settings: {'number_of_labels': 3, - 'coverage_percentage': 0.1} - - Returns: a task for the created batches - """ - - if self.queue_mode != QueueMode.Batch: - raise ValueError("Project must be in batch mode") - - dr_ids = [] - if data_rows is not None: - for dr in data_rows: - if isinstance(dr, Entity.DataRow): - dr_ids.append(dr.uid) - elif isinstance(dr, str): - dr_ids.append(dr) - else: - raise ValueError( - "`data_rows` must be DataRow ids or DataRow objects") - - self._wait_until_data_rows_are_processed( - dr_ids, global_keys, self._wait_processing_max_seconds) - - if consensus_settings: - consensus_settings = ConsensusSettings(**consensus_settings).dict( - by_alias=True) - - method = 'createBatches' - mutation_str = """mutation %sPyApi($projectId: ID!, $input: CreateBatchesInput!) 
{
- project(where: {id: $projectId}) {
- %s(input: $input) {
- tasks {
- batchUuid
- taskId
- }
- }
- }
- }
- """ % (method, method)
-
- params = {
- "projectId": self.uid,
- "input": {
- "batchNamePrefix": name_prefix,
- "dataRowIds": dr_ids,
- "globalKeys": global_keys,
- "priority": priority,
- "consensusSettings": consensus_settings
- }
- }
-
- tasks = self.client.execute(
- mutation_str, params, experimental=True)["project"][method]["tasks"]
- batch_ids = [task["batchUuid"] for task in tasks]
- task_ids = [task["taskId"] for task in tasks]
-
- return CreateBatchesTask(self.client, self.uid, batch_ids, task_ids)
-
- def create_batches_from_dataset(
- self,
- name_prefix: str,
- dataset_id: str,
- priority: int = 5,
- consensus_settings: Optional[Dict[str,
- float]] = None) -> CreateBatchesTask:
- """
- Creates batches for a project from a dataset, selecting only the data rows that are not already added to the
- project. When the dataset contains more than 100k data rows and multiple batches are needed, the specific batch
- that each data row will be placed in is undefined. Note that data rows may not be immediately available for a
- project after being added to a dataset; use the `_wait_until_data_rows_are_processed` method to ensure that
- data rows are available before creating batches.
-
- Batches will be created with the specified name prefix and a unique suffix. The suffix will be a 4-digit
- number starting at 0000. For example, if the name prefix is "batch" and 3 batches are created, the names
- will be "batch0000", "batch0001", and "batch0002". This method will throw an error if a batch with the same
- name already exists.
-
- Args:
- name_prefix: a prefix for the batch names, must be unique within a project
- dataset_id: the id of the dataset to create batches from
- priority: An optional priority for the Data Rows in the Batch. 1 highest -> 5 lowest
- consensus_settings: An optional dictionary with consensus settings: {'number_of_labels': 3,
- 'coverage_percentage': 0.1}
-
- Returns: a task for the created batches
- """
-
- if self.queue_mode != QueueMode.Batch:
- raise ValueError("Project must be in batch mode")
-
- if consensus_settings:
- consensus_settings = ConsensusSettings(**consensus_settings).dict(
- by_alias=True)
-
- logger.info("Creating batches from dataset %s", dataset_id)
-
- method = 'createBatchesFromDataset'
- mutation_str = """mutation %sPyApi($projectId: ID!, $input: CreateBatchesFromDatasetInput!) {
- project(where: {id: $projectId}) {
- %s(input: $input) {
- tasks {
- batchUuid
- taskId
- }
- }
- }
- }
- """ % (method, method)
-
- params = {
- "projectId": self.uid,
- "input": {
- "batchNamePrefix": name_prefix,
- "datasetId": dataset_id,
- "priority": priority,
- "consensusSettings": consensus_settings
- }
- }
-
- tasks = self.client.execute(
- mutation_str, params, experimental=True)["project"][method]["tasks"]
-
- batch_ids = [task["batchUuid"] for task in tasks]
- task_ids = [task["taskId"] for task in tasks]
-
- return CreateBatchesTask(self.client, self.uid, batch_ids, task_ids)
-
- def _create_batch_sync(self, name, dr_ids, global_keys, priority,
- consensus_settings):
- method = 'createBatchV2'
- query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) 
{ - project(where: {id: $projectId}) { - %s(input: $batchInput) { - batch { - %s - } - failedDataRowIds - } - } - } - """ % (method, method, query.results_query_part(Entity.Batch)) - params = { - "projectId": self.uid, - "batchInput": { - "name": name, - "dataRowIds": dr_ids, - "globalKeys": global_keys, - "priority": priority, - "consensusSettings": consensus_settings - } - } - res = self.client.execute(query_str, - params, - timeout=180.0, - experimental=True)["project"][method] - batch = res['batch'] - batch['size'] = res['batch']['size'] - return Entity.Batch(self.client, - self.uid, - batch, - failed_data_row_ids=res['failedDataRowIds']) - - def _create_batch_async(self, - name: str, - dr_ids: Optional[List[str]] = None, - global_keys: Optional[List[str]] = None, - priority: int = 5, - consensus_settings: Optional[Dict[str, - float]] = None): - method = 'createEmptyBatch' - create_empty_batch_mutation_str = """mutation %sPyApi($projectId: ID!, $input: CreateEmptyBatchInput!) { - project(where: {id: $projectId}) { - %s(input: $input) { - id - } - } - } - """ % (method, method) - - params = { - "projectId": self.uid, - "input": { - "name": name, - "consensusSettings": consensus_settings - } - } - - res = self.client.execute(create_empty_batch_mutation_str, - params, - timeout=180.0, - experimental=True)["project"][method] - batch_id = res['id'] - - method = 'addDataRowsToBatchAsync' - add_data_rows_mutation_str = """mutation %sPyApi($projectId: ID!, $input: AddDataRowsToBatchInput!) { - project(where: {id: $projectId}) { - %s(input: $input) { - taskId - } - } - } - """ % (method, method) - - params = { - "projectId": self.uid, - "input": { - "batchId": batch_id, - "dataRowIds": dr_ids, - "globalKeys": global_keys, - "priority": priority, - } - } - - res = self.client.execute(add_data_rows_mutation_str, - params, - timeout=180.0, - experimental=True)["project"][method] - - task_id = res['taskId'] - - task = self._wait_for_task(task_id) - if task.status != "COMPLETE": - raise LabelboxError(f"Batch was not created successfully: " + - json.dumps(task.errors)) - - return self.client.get_batch(self.uid, batch_id) - - def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode": - """ - Updates the queueing mode of this project. - - Deprecation notice: This method is deprecated. Going forward, projects must - go through a migration to have the queue mode changed. Users should specify the - queue mode for a project during creation if a non-default mode is desired. - - For more information, visit https://docs.labelbox.com/reference/migrating-to-workflows#upcoming-changes - - Args: - mode: the specified queue mode - - Returns: the updated queueing mode of this project - - """ - - logger.warning( - "Updating the queue_mode for a project will soon no longer be supported." - ) - - if self.queue_mode == mode: - return mode - - if mode == QueueMode.Batch: - status = "ENABLED" - elif mode == QueueMode.Dataset: - status = "DISABLED" - else: - raise ValueError( - "Must provide either `BATCH` or `DATASET` as a mode") - - query_str = """mutation %s($projectId: ID!, $status: TagSetStatusInput!) { - project(where: {id: $projectId}) { - setTagSetStatus(input: {tagSetStatus: $status}) { - tagSetStatus - } - } - } - """ % "setTagSetStatusPyApi" - - self.client.execute(query_str, { - 'projectId': self.uid, - 'status': status - }) - - return mode - - def get_label_count(self) -> int: - """ - Returns: the total number of labels in this project. - """ - - query_str = """query LabelCountPyApi($projectId: ID!) 
{
- project(where: {id: $projectId}) {
- labelCount
- }
- }"""
-
- res = self.client.execute(query_str, {'projectId': self.uid})
- return res["project"]["labelCount"]
-
- def get_queue_mode(self) -> "QueueMode":
- """
- Provides the queue mode used for this project.
-
- Deprecation notice: This method is deprecated and will be removed in
- a future version. To obtain the queue mode of a project, simply refer
- to the queue_mode attribute of a Project.
-
- For more information, visit https://docs.labelbox.com/reference/migrating-to-workflows#upcoming-changes
-
- Returns: the QueueMode for this project
-
- """
-
- logger.warning(
- "Obtaining the queue_mode for a project through this method will soon"
- " no longer be supported.")
-
- query_str = """query %s($projectId: ID!) {
- project(where: {id: $projectId}) {
- tagSetStatus
- }
- }
- """ % "GetTagSetStatusPyApi"
-
- status = self.client.execute(
- query_str, {'projectId': self.uid})["project"]["tagSetStatus"]
-
- if status == "ENABLED":
- return QueueMode.Batch
- elif status == "DISABLED":
- return QueueMode.Dataset
- else:
- raise ValueError("Status not known")
-
- def set_labeling_parameter_overrides(
- self, data: List[LabelingParameterOverrideInput]) -> bool:
- """ Adds labeling parameter overrides to this project.
-
- See information on priority here:
- https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system
-
- >>> project.set_labeling_parameter_overrides([
- >>> (data_row_id1, 2), (data_row_id2, 1)])
- or
- >>> project.set_labeling_parameter_overrides([
- >>> (data_row_gk1, 2), (data_row_gk2, 1)])
-
- Args:
- data (iterable): An iterable of tuples. Each tuple must contain
- either (DataRow, DataRowPriority)
- or (DataRowIdentifier, priority) for the new override.
- DataRowIdentifier is an object representing a data row id or a global key. A DataRowIdentifier object can be a UniqueIds or GlobalKeys class.
- NOTE - passing a whole DataRow is deprecated. Please use a DataRowIdentifier instead.
-
- Priority:
- * Data will be labeled in priority order.
- - A lower number priority is labeled first.
- - All signed 32-bit integers are accepted, from -2147483648 to 2147483647.
- * Priority is not the queue position.
- - The position is determined by the relative priority.
- - E.g. [(data_row_1, 5,1), (data_row_2, 2,1), (data_row_3, 10,1)]
- will be assigned in the following order: [data_row_2, data_row_1, data_row_3]
- * The priority only affects items in the queue.
- - Assigning a priority will not automatically add the item back into the queue.
- Returns:
- bool, indicates if the operation was a success.
- """
- data = [t[:2] for t in data]
- validate_labeling_parameter_overrides(data)
-
- template = Template(
- """mutation SetLabelingParameterOverridesPyApi($$projectId: ID!)
- {project(where: { id: $$projectId })
- {setLabelingParameterOverrides
- (dataWithDataRowIdentifiers: [$dataWithDataRowIdentifiers])
- {success}}}
- """)
-
- data_rows_with_identifiers = ""
- for data_row, priority in data:
- if isinstance(data_row, DataRow):
- data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.uid}\", idType: {IdType.DataRowId}}}, priority: {priority}}},"
- elif isinstance(data_row, UniqueId) or isinstance(
- data_row, GlobalKey):
- data_rows_with_identifiers += f"{{dataRowIdentifier: {{id: \"{data_row.key}\", idType: {data_row.id_type}}}, priority: {priority}}},"
- else:
- raise TypeError(
- f"Data row identifier should be of type DataRow or Data Row Identifier. Found {type(data_row)}." 
- ) - - query_str = template.substitute( - dataWithDataRowIdentifiers=data_rows_with_identifiers) - res = self.client.execute(query_str, {"projectId": self.uid}) - return res["project"]["setLabelingParameterOverrides"]["success"] - - @overload - def update_data_row_labeling_priority( - self, - data_rows: DataRowIdentifiers, - priority: int, - ) -> bool: - pass - - @overload - def update_data_row_labeling_priority( - self, - data_rows: List[str], - priority: int, - ) -> bool: - pass - - def update_data_row_labeling_priority( - self, - data_rows, - priority: int, - ) -> bool: - """ - Updates labeling parameter overrides to this project in bulk. This method allows up to 1 million data rows to be - updated at once. - - See information on priority here: - https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system - - Args: - data_rows: a list of data row ids to update priorities for. This can be a list of strings or a DataRowIdentifiers object - DataRowIdentifier objects are lists of ids or global keys. A DataIdentifier object can be a UniqueIds or GlobalKeys class. - priority (int): Priority for the new override. See above for more information. - - Returns: - bool, indicates if the operation was a success. - """ - - if isinstance(data_rows, list): - data_rows = UniqueIds(data_rows) - warnings.warn("Using data row ids will be deprecated. Please use " - "UniqueIds or GlobalKeys instead.") - - method = "createQueuePriorityUpdateTask" - priority_param = "priority" - project_param = "projectId" - data_rows_param = "dataRowIdentifiers" - query_str = """mutation %sPyApi( - $%s: Int! - $%s: ID! - $%s: QueuePriorityUpdateDataRowIdentifiersInput - ) { - project(where: { id: $%s }) { - %s( - data: { priority: $%s, dataRowIdentifiers: $%s } - ) { - taskId - } - } - } - """ % (method, priority_param, project_param, data_rows_param, - project_param, method, priority_param, data_rows_param) - res = self.client.execute( - query_str, { - priority_param: priority, - project_param: self.uid, - data_rows_param: { - "ids": [id for id in data_rows], - "idType": data_rows.id_type, - }, - })["project"][method] - - task_id = res['taskId'] - - task = self._wait_for_task(task_id) - if task.status != "COMPLETE": - raise LabelboxError(f"Priority was not updated successfully: " + - json.dumps(task.errors)) - return True - - def extend_reservations(self, queue_type) -> int: - """ Extends all the current reservations for the current user on the given - queue type. - Args: - queue_type (str): Either "LabelingQueue" or "ReviewQueue" - Returns: - int, the number of reservations that were extended. - """ - if queue_type not in ("LabelingQueue", "ReviewQueue"): - raise InvalidQueryError("Unsupported queue type: %s" % queue_type) - - id_param = "projectId" - query_str = """mutation ExtendReservationsPyApi($%s: ID!){ - extendReservations(projectId:$%s queueType:%s)}""" % ( - id_param, id_param, queue_type) - res = self.client.execute(query_str, {id_param: self.uid}) - return res["extendReservations"] - - def enable_model_assisted_labeling(self, toggle: bool = True) -> bool: - """ Turns model assisted labeling either on or off based on input - - Args: - toggle (bool): True or False boolean - Returns: - True if toggled on or False if toggled off - """ - project_param = "project_id" - show_param = "show" - - query_str = """mutation toggle_model_assisted_labelingPyApi($%s: ID!, $%s: Boolean!) 
{ - project(where: {id: $%s }) { - showPredictionsToLabelers(show: $%s) { - id, showingPredictionsToLabelers - } - } - }""" % (project_param, show_param, project_param, show_param) - - params = {project_param: self.uid, show_param: toggle} - - res = self.client.execute(query_str, params) - return res["project"]["showPredictionsToLabelers"][ - "showingPredictionsToLabelers"] - - def bulk_import_requests(self) -> PaginatedCollection: - """ Returns bulk import request objects which are used in model-assisted labeling. - These are returned with the oldest first, and most recent last. - """ - - id_param = "project_id" - query_str = """query ListAllImportRequestsPyApi($%s: ID!) { - bulkImportRequests ( - where: { projectId: $%s } - skip: %%d - first: %%d - ) { - %s - } - }""" % (id_param, id_param, - query.results_query_part(Entity.BulkImportRequest)) - return PaginatedCollection(self.client, query_str, - {id_param: str(self.uid)}, - ["bulkImportRequests"], - Entity.BulkImportRequest) - - def batches(self) -> PaginatedCollection: - """ Fetch all batches that belong to this project - - Returns: - A `PaginatedCollection` of `Batch`es - """ - id_param = "projectId" - query_str = """query GetProjectBatchesPyApi($from: String, $first: PageSize, $%s: ID!) { - project(where: {id: $%s}) {id - batches(after: $from, first: $first) { nodes { %s } pageInfo { endCursor }}}} - """ % (id_param, id_param, query.results_query_part(Entity.Batch)) - return PaginatedCollection( - self.client, - query_str, {id_param: self.uid}, ['project', 'batches', 'nodes'], - lambda client, res: Entity.Batch(client, self.uid, res), - cursor_path=['project', 'batches', 'pageInfo', 'endCursor'], - experimental=True) - - def task_queues(self) -> List[TaskQueue]: - """ Fetch all task queues that belong to this project - - Returns: - A `List` of `TaskQueue`s - """ - query_str = """query GetProjectTaskQueuesPyApi($projectId: ID!) { - project(where: {id: $projectId}) { - taskQueues { - %s - } - } - } - """ % (query.results_query_part(Entity.TaskQueue)) - - task_queue_values = self.client.execute( - query_str, {"projectId": self.uid}, - timeout=180.0, - experimental=True)["project"]["taskQueues"] - - return [ - Entity.TaskQueue(self.client, field_values) - for field_values in task_queue_values - ] - - @overload - def move_data_rows_to_task_queue(self, data_row_ids: DataRowIdentifiers, - task_queue_id: str): - pass - - @overload - def move_data_rows_to_task_queue(self, data_row_ids: List[str], - task_queue_id: str): - pass - - def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str): - """ - - Moves data rows to the specified task queue. - - Args: - data_row_ids: a list of data row ids to be moved. This can be a list of strings or a DataRowIdentifiers object - DataRowIdentifier objects are lists of ids or global keys. A DataIdentifier object can be a UniqueIds or GlobalKeys class. - task_queue_id: the task queue id to be moved to, or None to specify the "Done" queue - - Returns: - None if successful, or a raised error on failure - - """ - if isinstance(data_row_ids, list): - data_row_ids = UniqueIds(data_row_ids) - warnings.warn("Using data row ids will be deprecated. Please use " - "UniqueIds or GlobalKeys instead.") - - method = "createBulkAddRowsToQueueTask" - query_str = """mutation AddDataRowsToTaskQueueAsyncPyApi( - $projectId: ID! - $queueId: ID - $dataRowIdentifiers: AddRowsToTaskQueueViaDataRowIdentifiersInput! 
- ) { - project(where: { id: $projectId }) { - %s( - data: { queueId: $queueId, dataRowIdentifiers: $dataRowIdentifiers } - ) { - taskId - } - } - } - """ % method - - task_id = self.client.execute( - query_str, { - "projectId": self.uid, - "queueId": task_queue_id, - "dataRowIdentifiers": { - "ids": [id for id in data_row_ids], - "idType": data_row_ids.id_type, - }, - }, - timeout=180.0, - experimental=True)["project"][method]["taskId"] - - task = self._wait_for_task(task_id) - if task.status != "COMPLETE": - raise LabelboxError(f"Data rows were not moved successfully: " + - json.dumps(task.errors)) - - def _wait_for_task(self, task_id: str) -> Task: - task = Task.get_task(self.client, task_id) - task.wait_till_done() - - return task - - def upload_annotations( - self, - name: str, - annotations: Union[str, Path, Iterable[Dict]], - validate: bool = False) -> 'BulkImportRequest': # type: ignore - """ Uploads annotations to a new Editor project. - - Args: - name (str): name of the BulkImportRequest job - annotations (str or Path or Iterable): - url that is publicly accessible by Labelbox containing an - ndjson file - OR local path to an ndjson file - OR iterable of annotation rows - validate (bool): - Whether or not to validate the payload before uploading. - Returns: - BulkImportRequest - """ - - if isinstance(annotations, str) or isinstance(annotations, Path): - - def _is_url_valid(url: Union[str, Path]) -> bool: - """ Verifies that the given string is a valid url. - - Args: - url: string to be checked - Returns: - True if the given url is valid otherwise False - - """ - if isinstance(url, Path): - return False - parsed = urlparse(url) - return bool(parsed.scheme) and bool(parsed.netloc) - - if _is_url_valid(annotations): - return Entity.BulkImportRequest.create_from_url( - client=self.client, - project_id=self.uid, - name=name, - url=str(annotations), - validate=validate) - else: - path = Path(annotations) - if not path.exists(): - raise FileNotFoundError( - f'{annotations} is not a valid url nor existing local file' - ) - return Entity.BulkImportRequest.create_from_local_file( - client=self.client, - project_id=self.uid, - name=name, - file=path, - validate_file=validate, - ) - elif isinstance(annotations, Iterable): - return Entity.BulkImportRequest.create_from_objects( - client=self.client, - project_id=self.uid, - name=name, - predictions=annotations, # type: ignore - validate=validate) - else: - raise ValueError( - f'Invalid annotations given of type: {type(annotations)}') - - def _wait_until_data_rows_are_processed( - self, - data_row_ids: Optional[List[str]] = None, - global_keys: Optional[List[str]] = None, - wait_processing_max_seconds: int = _wait_processing_max_seconds, - sleep_interval=30): - """ Wait until all the specified data rows are processed""" - start_time = datetime.now() - - max_data_rows_per_poll = 100_000 - if data_row_ids is not None: - for i in range(0, len(data_row_ids), max_data_rows_per_poll): - chunk = data_row_ids[i:i + max_data_rows_per_poll] - self._poll_data_row_processing_status( - chunk, [], start_time, wait_processing_max_seconds, - sleep_interval) - - if global_keys is not None: - for i in range(0, len(global_keys), max_data_rows_per_poll): - chunk = global_keys[i:i + max_data_rows_per_poll] - self._poll_data_row_processing_status( - [], chunk, start_time, wait_processing_max_seconds, - sleep_interval) - - def _poll_data_row_processing_status( - self, - data_row_ids: List[str], - global_keys: List[str], - start_time: datetime, - 
wait_processing_max_seconds: int = _wait_processing_max_seconds, - sleep_interval=30): - - while True: - if (datetime.now() - - start_time).total_seconds() >= wait_processing_max_seconds: - raise ProcessingWaitTimeout( - """Maximum wait time exceeded while waiting for data rows to be processed. - Try creating a batch a bit later""") - - all_good = self.__check_data_rows_have_been_processed( - data_row_ids, global_keys) - if all_good: - return - - logger.debug( - 'Some of the data rows are still being processed, waiting...') - time.sleep(sleep_interval) - - def __check_data_rows_have_been_processed( - self, - data_row_ids: Optional[List[str]] = None, - global_keys: Optional[List[str]] = None): - - if data_row_ids is not None and len(data_row_ids) > 0: - param_name = "dataRowIds" - params = {param_name: data_row_ids} - else: - param_name = "globalKeys" - global_keys = global_keys if global_keys is not None else [] - params = {param_name: global_keys} - - query_str = """query CheckAllDataRowsHaveBeenProcessedPyApi($%s: [ID!]) { - queryAllDataRowsHaveBeenProcessed(%s:$%s) { - allDataRowsHaveBeenProcessed - } - }""" % (param_name, param_name, param_name) - - response = self.client.execute(query_str, params) - return response["queryAllDataRowsHaveBeenProcessed"][ - "allDataRowsHaveBeenProcessed"] - - -class ProjectMember(DbObject): - user = Relationship.ToOne("User", cache=True) - role = Relationship.ToOne("Role", cache=True) - access_from = Field.String("access_from") - - -class LabelingParameterOverride(DbObject): - """ Customizes the order of assets in the label queue. - - Attributes: - priority (int): A prioritization score. - number_of_labels (int): Number of times an asset should be labeled. - """ - priority = Field.Int("priority") - number_of_labels = Field.Int("number_of_labels") - - data_row = Relationship.ToOne("DataRow", cache=True) - - -LabelerPerformance = namedtuple( - "LabelerPerformance", "user count seconds_per_label, total_time_labeling " - "consensus average_benchmark_agreement last_activity_time") -LabelerPerformance.__doc__ = ( - "Named tuple containing info about a labeler's performance.") - - -def _check_converter_import(): - if 'LBV1Converter' not in globals(): - raise ImportError( - "Missing dependencies to import converter. " - "Use `pip install labelbox[data] --upgrade` to add missing dependencies. 
" - "or download raw json with project.export_labels()") - ----- -labelbox/schema/export_task.py -from abc import ABC, abstractmethod -from dataclasses import dataclass -from enum import Enum -from functools import lru_cache -from io import TextIOWrapper -import json -from pathlib import Path -from typing import ( - Callable, - Generic, - Iterator, - List, - Optional, - Tuple, - TypeVar, - Union, - TYPE_CHECKING, - overload, -) - -import requests -from labelbox import pydantic_compat - -from labelbox.schema.task import Task -from labelbox.utils import _CamelCaseMixin - -if TYPE_CHECKING: - from labelbox import Client - -OutputT = TypeVar("OutputT") - - -class StreamType(Enum): - """The type of the stream.""" - - RESULT = "RESULT" - ERRORS = "ERRORS" - - -class Range(_CamelCaseMixin, pydantic_compat.BaseModel): # pylint: disable=too-few-public-methods - """Represents a range.""" - - start: int - end: int - - -class _MetadataHeader(_CamelCaseMixin, pydantic_compat.BaseModel): # pylint: disable=too-few-public-methods - total_size: int - total_lines: int - - -class _MetadataFileInfo(_CamelCaseMixin, pydantic_compat.BaseModel): # pylint: disable=too-few-public-methods - offsets: Range - lines: Range - file: str - - -@dataclass -class _TaskContext: - client: "Client" - task_id: str - stream_type: StreamType - metadata_header: _MetadataHeader - - -class Converter(ABC, Generic[OutputT]): - """Abstract class for transforming data.""" - - @dataclass - class ConverterInputArgs: - """Input for the converter.""" - - ctx: _TaskContext - file_info: _MetadataFileInfo - raw_data: str - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - return False - - @abstractmethod - def convert(self, input_args: ConverterInputArgs) -> Iterator[OutputT]: - """Converts the data. - Returns an iterator that yields the converted data. - - Args: - current_offset: The global offset indicating the position of the data within the - exported files. It represents a cumulative offset in characters - across multiple files. - raw_data: The raw data to convert. - - Yields: - Iterator[OutputT]: The converted data. 
- """ - - -@dataclass -class JsonConverterOutput: - """Output with the JSON string.""" - - current_offset: int - current_line: int - json_str: str - - -class JsonConverter(Converter[JsonConverterOutput]): # pylint: disable=too-few-public-methods - """Converts JSON data.""" - - def _find_json_object_offsets(self, data: str) -> List[Tuple[int, int]]: - object_offsets: List[Tuple[int, int]] = [] - stack = [] - current_object_start = None - - for index, char in enumerate(data): - if char == "{": - stack.append(char) - if len(stack) == 1: - current_object_start = index - # we need to account for scenarios where data lands in the middle of an object - # and the object is not the last one in the data - if index > 0 and data[index - - 1] == "\n" and not object_offsets: - object_offsets.append((0, index - 1)) - elif char == "}" and stack: - stack.pop() - # this covers cases where the last object is either followed by a newline or - # it is missing - if len(stack) == 0 and (len(data) == index + 1 or - data[index + 1] == "\n" - ) and current_object_start is not None: - object_offsets.append((current_object_start, index + 1)) - current_object_start = None - - # we also need to account for scenarios where data lands in the middle of the last object - return object_offsets if object_offsets else [(0, len(data) - 1)] - - def convert( - self, input_args: Converter.ConverterInputArgs - ) -> Iterator[JsonConverterOutput]: - current_offset, current_line, raw_data = ( - input_args.file_info.offsets.start, - input_args.file_info.lines.start, - input_args.raw_data, - ) - offsets = self._find_json_object_offsets(raw_data) - for line, (offset_start, offset_end) in enumerate(offsets): - yield JsonConverterOutput( - current_offset=current_offset + offset_start, - current_line=current_line + line, - json_str=raw_data[offset_start:offset_end + 1].strip(), - ) - - -@dataclass -class FileConverterOutput: - """Output with statistics about the written file.""" - - file_path: Path - total_size: int - total_lines: int - current_offset: int - current_line: int - bytes_written: int - - -class FileConverter(Converter[FileConverterOutput]): - """Converts data to a file.""" - - def __init__(self, file_path: str) -> None: - super().__init__() - self._file: Optional[TextIOWrapper] = None - self._file_path = file_path - - def __enter__(self): - self._file = open(self._file_path, "w", encoding="utf-8") - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._file: - self._file.close() - return False - - def convert( - self, input_args: Converter.ConverterInputArgs - ) -> Iterator[FileConverterOutput]: - # appends data to the file - assert self._file is not None - self._file.write(input_args.raw_data) - yield FileConverterOutput( - file_path=Path(self._file_path), - total_size=input_args.ctx.metadata_header.total_size, - total_lines=input_args.ctx.metadata_header.total_lines, - current_offset=input_args.file_info.offsets.start, - current_line=input_args.file_info.lines.start, - bytes_written=len(input_args.raw_data), - ) - - -class FileRetrieverStrategy(ABC): # pylint: disable=too-few-public-methods - """Abstract class for retrieving files.""" - - def __init__(self, ctx: _TaskContext) -> None: - super().__init__() - self._ctx = ctx - - @abstractmethod - def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]: - """Retrieves the file.""" - - def _get_file_content( - self, query: str, variables: dict, - result_field_name: str) -> Tuple[_MetadataFileInfo, str]: - """Runs the query.""" - res = 
self._ctx.client.execute(query, variables, error_log_key="errors") - res = res["task"][result_field_name] - file_info = _MetadataFileInfo(**res) if res else None - if not file_info: - raise ValueError( - f"Task {self._ctx.task_id} does not have a metadata file for the " - f"{self._ctx.stream_type.value} stream") - response = requests.get(file_info.file, timeout=30) - response.raise_for_status() - assert len( - response.text - ) == file_info.offsets.end - file_info.offsets.start + 1, ( - f"expected {file_info.offsets.end - file_info.offsets.start + 1} bytes, " - f"got {len(response.text)} bytes") - return file_info, response.text - - -class FileRetrieverByOffset(FileRetrieverStrategy): # pylint: disable=too-few-public-methods - """Retrieves files by offset.""" - - def __init__( - self, - ctx: _TaskContext, - offset: int, - ) -> None: - super().__init__(ctx) - self._current_offset = offset - self._current_line: Optional[int] = None - if self._current_offset >= self._ctx.metadata_header.total_size: - raise ValueError( - f"offset is out of range, max offset is {self._ctx.metadata_header.total_size - 1}" - ) - - def _find_line_at_offset(self, file_content: str, - target_offset: int) -> int: - stack = [] - line_number = 0 - - for index, char in enumerate(file_content): - if char == "{": - stack.append(char) - if len(stack) == 1 and index > 0: - line_number += 1 - elif char == "}" and stack: - stack.pop() - - if index == target_offset: - break - - return line_number - - def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]: - if self._current_offset >= self._ctx.metadata_header.total_size: - return None - query = ( - f"query GetExportFileFromOffsetPyApi" - f"($where: WhereUniqueIdInput, $streamType: TaskStreamType!, $offset: UInt64!)" - f"{{task(where: $where)" - f"{{{'exportFileFromOffset'}(streamType: $streamType, offset: $offset)" - f"{{offsets {{start end}} lines {{start end}} file}}" - f"}}}}") - variables = { - "where": { - "id": self._ctx.task_id - }, - "streamType": self._ctx.stream_type.value, - "offset": str(self._current_offset), - } - file_info, file_content = self._get_file_content( - query, variables, "exportFileFromOffset") - if self._current_line is None: - self._current_line = self._find_line_at_offset( - file_content, self._current_offset - file_info.offsets.start) - self._current_line += file_info.lines.start - file_content = file_content[self._current_offset - - file_info.offsets.start:] - file_info.offsets.start = self._current_offset - file_info.lines.start = self._current_line - self._current_offset = file_info.offsets.end + 1 - self._current_line = file_info.lines.end + 1 - return file_info, file_content - - -class FileRetrieverByLine(FileRetrieverStrategy): # pylint: disable=too-few-public-methods - """Retrieves files by line.""" - - def __init__( - self, - ctx: _TaskContext, - line: int, - ) -> None: - super().__init__(ctx) - self._current_line = line - self._current_offset: Optional[int] = None - if self._current_line >= self._ctx.metadata_header.total_lines: - raise ValueError( - f"line is out of range, max line is {self._ctx.metadata_header.total_lines - 1}" - ) - - def _find_offset_of_line(self, file_content: str, target_line: int): - start_offset = None - stack = [] - line_number = 0 - - for index, char in enumerate(file_content): - if char == "{": - stack.append(char) - if len(stack) == 1: - if line_number == target_line: - start_offset = index - line_number += 1 - elif char == "}" and stack: - stack.pop() - - if line_number > target_line: - break - - 
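- # start_offset stays None when target_line is not found in file_content;
- # callers translate this file-local offset into a global one by adding
- # file_info.offsets.start (see get_next_chunk below).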
return start_offset - - def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]: - if self._current_line >= self._ctx.metadata_header.total_lines: - return None - query = ( - f"query GetExportFileFromLinePyApi" - f"($where: WhereUniqueIdInput, $streamType: TaskStreamType!, $line: UInt64!)" - f"{{task(where: $where)" - f"{{{'exportFileFromLine'}(streamType: $streamType, line: $line)" - f"{{offsets {{start end}} lines {{start end}} file}}" - f"}}}}") - variables = { - "where": { - "id": self._ctx.task_id - }, - "streamType": self._ctx.stream_type.value, - "line": self._current_line, - } - file_info, file_content = self._get_file_content( - query, variables, "exportFileFromLine") - if self._current_offset is None: - self._current_offset = self._find_offset_of_line( - file_content, self._current_line - file_info.lines.start) - self._current_offset += file_info.offsets.start - file_content = file_content[self._current_offset - - file_info.offsets.start:] - file_info.offsets.start = self._current_offset - file_info.lines.start = self._current_line - self._current_offset = file_info.offsets.end + 1 - self._current_line = file_info.lines.end + 1 - return file_info, file_content - - -class _Reader(ABC): # pylint: disable=too-few-public-methods - """Abstract class for reading data from a source.""" - - @abstractmethod - def set_retrieval_strategy(self, strategy: FileRetrieverStrategy) -> None: - """Sets the retrieval strategy.""" - - @abstractmethod - def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]: - """Reads data from the source.""" - - -class _MultiGCSFileReader(_Reader): # pylint: disable=too-few-public-methods - """Reads data from multiple GCS files in a seamless way.""" - - def __init__(self): - super().__init__() - self._retrieval_strategy = None - - def set_retrieval_strategy(self, strategy: FileRetrieverStrategy) -> None: - """Sets the retrieval strategy.""" - self._retrieval_strategy = strategy - - def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]: - if not self._retrieval_strategy: - raise ValueError("retrieval strategy not set") - result = self._retrieval_strategy.get_next_chunk() - while result: - file_info, raw_data = result - yield file_info, raw_data - result = self._retrieval_strategy.get_next_chunk() - - -class Stream(Generic[OutputT]): - """Streams data from a Reader.""" - - def __init__( - self, - ctx: _TaskContext, - reader: _Reader, - converter: Converter, - ): - self._ctx = ctx - self._reader = reader - self._converter = converter - # default strategy is to retrieve files by offset, starting from 0 - self.with_offset(0) - - def __iter__(self): - yield from self._fetch() - - def _fetch(self,) -> Iterator[OutputT]: - """Fetches the result data. - Returns an iterator that yields the offset and the data. 
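-
- Example (illustrative; `stream` is a Stream built with the default
- JsonConverter):
- >>> for output in stream:
- >>> print(output.current_line, output.json_str)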
-
- """
- if self._ctx.metadata_header.total_size is None:
- return
-
- stream = self._reader.read()
- with self._converter as converter:
- for file_info, raw_data in stream:
- for output in converter.convert(
- Converter.ConverterInputArgs(self._ctx, file_info,
- raw_data)):
- yield output
-
- def with_offset(self, offset: int) -> "Stream[OutputT]":
- """Sets the offset for the stream."""
- self._reader.set_retrieval_strategy(
- FileRetrieverByOffset(self._ctx, offset))
- return self
-
- def with_line(self, line: int) -> "Stream[OutputT]":
- """Sets the line number for the stream."""
- self._reader.set_retrieval_strategy(FileRetrieverByLine(
- self._ctx, line))
- return self
-
- def start(
- self,
- stream_handler: Optional[Callable[[OutputT], None]] = None) -> None:
- """Starts streaming the result data.
- Calls the stream_handler for each result.
- """
- # this calls the __iter__ method, which in turn calls the _fetch method
- for output in self:
- if stream_handler:
- stream_handler(output)
-
-
-class ExportTask:
- """
- An adapter class for working with task objects, providing extended functionality
- and convenient access to task-related information.
-
- This class wraps a `Task` object, allowing you to interact with tasks of this type.
- It offers methods to retrieve task results, errors, and metadata, as well as properties
- for accessing task details such as UID, status, and creation time.
- """
-
- class ExportTaskException(Exception):
- """Raised when the task is not ready yet or has failed."""
-
- def __init__(self, task: Task) -> None:
- self._task = task
-
- def __repr__(self):
- return f"<ExportTask ID: {self.uid}>" if getattr(
- self, "uid", None) else "<ExportTask>"
-
- def __str__(self):
- properties_to_include = [
- "completion_percentage",
- "created_at",
- "metadata",
- "name",
- "result",
- "status",
- "type",
- "uid",
- "updated_at",
- ]
- props = {prop: getattr(self, prop) for prop in properties_to_include}
- return f"<ExportTask {json.dumps(props, indent=4, default=str)}>"
-
- def __eq__(self, other):
- return self._task.__eq__(other)
-
- def __hash__(self):
- return self._task.__hash__()
-
- @property
- def uid(self):
- """Returns the uid of the task."""
- return self._task.uid
-
- @property
- def deleted(self):
- """Returns whether the task is deleted."""
- return self._task.deleted
-
- @property
- def updated_at(self):
- """Returns the last time the task was updated."""
- return self._task.updated_at
-
- @property
- def created_at(self):
- """Returns the time the task was created."""
- return self._task.created_at
-
- @property
- def name(self):
- """Returns the name of the task."""
- return self._task.name
-
- @property
- def status(self):
- """Returns the status of the task."""
- return self._task.status
-
- @property
- def metadata(self):
- """Returns the metadata of the task."""
- return self._task.metadata
-
- @property
- def result(self):
- """Returns the result of the task."""
- return self._task.result_url
-
- @property
- def completion_percentage(self):
- """Returns the completion percentage of the task."""
- return self._task.completion_percentage
-
- @property
- def type(self):
- """Returns the type of the task."""
- return self._task.type
-
- @property
- def created_by(self):
- """Returns the user who created the task."""
- return self._task.created_by
-
- @property
- def organization(self):
- """Returns the organization of the task."""
- return self._task.organization
-
- def wait_till_done(self, timeout_seconds: int = 300) -> None:
- """Waits until the task is done."""
- return self._task.wait_till_done(timeout_seconds)
-
- @staticmethod
- 
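- # Cached (maxsize=5) so repeated size/line lookups for the same task and
- # stream type reuse one metadata query instead of re-hitting the API.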
@lru_cache(maxsize=5)
- def _get_metadata_header(
- client, task_id: str,
- stream_type: StreamType) -> Union[_MetadataHeader, None]:
- """Returns the metadata header (total size and line count) for a task's stream."""
- query = (f"query GetExportMetadataHeaderPyApi"
- f"($where: WhereUniqueIdInput, $streamType: TaskStreamType!)"
- f"{{task(where: $where)"
- f"{{{'exportMetadataHeader'}(streamType: $streamType)"
- f"{{totalSize totalLines}}"
- f"}}}}")
- variables = {"where": {"id": task_id}, "streamType": stream_type.value}
- res = client.execute(query, variables, error_log_key="errors")
- res = res["task"]["exportMetadataHeader"]
- return _MetadataHeader(**res) if res else None
-
- def get_total_file_size(self, stream_type: StreamType) -> Union[int, None]:
- """Returns the total file size for a specific task."""
- if self._task.status == "FAILED":
- raise ExportTask.ExportTaskException("Task failed")
- if self._task.status != "COMPLETE":
- raise ExportTask.ExportTaskException("Task is not ready yet")
- header = ExportTask._get_metadata_header(self._task.client,
- self._task.uid, stream_type)
- return header.total_size if header else None
-
- def get_total_lines(self, stream_type: StreamType) -> Union[int, None]:
- """Returns the total number of lines for a specific task."""
- if self._task.status == "FAILED":
- raise ExportTask.ExportTaskException("Task failed")
- if self._task.status != "COMPLETE":
- raise ExportTask.ExportTaskException("Task is not ready yet")
- header = ExportTask._get_metadata_header(self._task.client,
- self._task.uid, stream_type)
- return header.total_lines if header else None
-
- def has_result(self) -> bool:
- """Returns whether the task has a result."""
- total_size = self.get_total_file_size(StreamType.RESULT)
- return total_size is not None and total_size > 0
-
- def has_errors(self) -> bool:
- """Returns whether the task has errors."""
- total_size = self.get_total_file_size(StreamType.ERRORS)
- return total_size is not None and total_size > 0
-
- @overload
- def get_stream(
- self,
- converter: JsonConverter = JsonConverter(),
- stream_type: StreamType = StreamType.RESULT,
- ) -> Stream[JsonConverterOutput]:
- """Overload for getting the right typing hints when using a JsonConverter."""
-
- @overload
- def get_stream(
- self,
- converter: FileConverter,
- stream_type: StreamType = StreamType.RESULT,
- ) -> Stream[FileConverterOutput]:
- """Overload for getting the right typing hints when using a FileConverter."""
-
- def get_stream(
- self,
- converter: Converter = JsonConverter(),
- stream_type: StreamType = StreamType.RESULT,
- ) -> Stream:
- """Returns a stream over the task's data for the given stream type."""
- if self._task.status == "FAILED":
- raise ExportTask.ExportTaskException("Task failed")
- if self._task.status != "COMPLETE":
- raise ExportTask.ExportTaskException("Task is not ready yet")
-
- metadata_header = self._get_metadata_header(self._task.client,
- self._task.uid, stream_type)
- if metadata_header is None:
- raise ValueError(
- f"Task {self._task.uid} does not have a {stream_type.value} stream"
- )
- return Stream(
- _TaskContext(self._task.client, self._task.uid, stream_type,
- metadata_header),
- _MultiGCSFileReader(),
- converter,
- )
-
- @staticmethod
- def get_task(client, task_id):
- """Returns the task with the given id."""
- return ExportTask(Task.get_task(client, task_id))
-
-----
-labelbox/schema/foundry/__init__.py
-
-----
-labelbox/schema/foundry/model.py
-from labelbox.utils import _CamelCaseMixin
-
-from labelbox import pydantic_compat
-
-from datetime import datetime
-from typing import 
Dict - - -class Model(_CamelCaseMixin, pydantic_compat.BaseModel): - id: str - description: str - inference_params_json_schema: Dict - name: str - ontology_id: str - created_at: datetime - - -MODEL_FIELD_NAMES = list(Model.schema()['properties'].keys()) - ----- -labelbox/schema/foundry/app.py -from labelbox.utils import _CamelCaseMixin - -from labelbox import pydantic_compat - -from typing import Any, Dict, Optional - - -class App(_CamelCaseMixin, pydantic_compat.BaseModel): - id: Optional[str] - model_id: str - name: str - description: Optional[str] = None - inference_params: Dict[str, Any] - class_to_schema_id: Dict[str, str] - ontology_id: str - created_by: Optional[str] = None - - @classmethod - def type_name(cls): - return "App" - - -APP_FIELD_NAMES = list(App.schema()['properties'].keys()) - ----- -labelbox/schema/foundry/foundry_client.py -from typing import Union -from labelbox import exceptions -from labelbox.schema.foundry.app import App, APP_FIELD_NAMES -from labelbox.schema.identifiables import DataRowIds, GlobalKeys, IdType -from labelbox.schema.task import Task - - -class FoundryClient: - - def __init__(self, client): - self.client = client - - def _create_app(self, app: App) -> App: - field_names_str = "\n".join(APP_FIELD_NAMES) - query_str = f""" - mutation CreateFoundryAppPyApi( - $name: String!, $modelId: ID!, $ontologyId: ID!, $description: String, $inferenceParams: Json!, $classToSchemaId: Json! - ){{ - createModelFoundryApp(input: {{ - name: $name - modelId: $modelId - ontologyId: $ontologyId - description: $description - inferenceParams: $inferenceParams - classToSchemaId: $classToSchemaId - }}) - {{ - {field_names_str} - }} - }} - """ - - params = app.dict(by_alias=True, exclude={"id"}) - - try: - response = self.client.execute(query_str, params) - except exceptions.LabelboxError as e: - raise exceptions.LabelboxError('Unable to create app', e) - return App(**response["createModelFoundryApp"]) - - def _get_app(self, id: str) -> App: - field_names_str = "\n".join(APP_FIELD_NAMES) - - query_str = f""" - query GetFoundryAppByIdPyApi($id: ID!) {{ - findModelFoundryApp(where: {{id: $id}}) {{ - {field_names_str} - }} - }} - """ - params = {"id": id} - - try: - response = self.client.execute(query_str, params) - except exceptions.InvalidQueryError as e: - raise exceptions.ResourceNotFoundError(App, params) - except Exception as e: - raise exceptions.LabelboxError(f'Unable to get app with id {id}', e) - return App(**response["findModelFoundryApp"]) - - def _delete_app(self, id: str) -> None: - query_str = """ - mutation DeleteFoundryAppPyApi($id: ID!) { - deleteModelFoundryApp(id: $id) { - success - } - } - """ - params = {"id": id} - try: - self.client.execute(query_str, params) - except Exception as e: - raise exceptions.LabelboxError(f'Unable to delete app with id {id}', - e) - - def run_app(self, model_run_name: str, - data_rows: Union[DataRowIds, GlobalKeys], app_id: str) -> Task: - app = self._get_app(app_id) - - params = { - "modelId": str(app.model_id), - "name": model_run_name, - "classToSchemaId": app.class_to_schema_id, - "inferenceParams": app.inference_params, - "ontologyId": app.ontology_id - } - - data_rows_key = "dataRowIds" if data_rows.id_type == IdType.DataRowId else "globalKeys" - params[data_rows_key] = list(data_rows) - - query = """ - mutation CreateModelJobPyApi($input: CreateModelJobForDataRowsInput!) 
{ - createModelJobForDataRows(input: $input) { - taskId - __typename - } - } - """ - try: - response = self.client.execute(query, {"input": params}) - except Exception as e: - raise exceptions.LabelboxError('Unable to run foundry app', e) - task_id = response["createModelJobForDataRows"]["taskId"] - return Task.get_task(self.client, task_id) - ----- -labelbox/orm/query.py -from itertools import chain -from typing import Any, Dict - -from labelbox import utils -from labelbox.exceptions import InvalidQueryError, InvalidAttributeError, MalformedQueryException -from labelbox.orm.comparison import LogicalExpression, Comparison -from labelbox.orm.model import Field, Relationship, Entity -""" Common query creation functionality. """ - - -def format_param_declaration(params): - """ Formats the parameters dictionary into a declaration of GraphQL - query parameters. - - Args: - params (dict): keys are query param names and values are - (value, (field|relationship)) tuples. - Return: - str, the declaration of query parameters. - """ - if not params: - return "" - - def attr_type(attr): - if isinstance(attr, Field): - return attr.field_type.name - else: - return Field.Type.ID.name - - return "(" + ", ".join("$%s: %s!" % (param, attr_type(attr)) - for param, (_, attr) in params.items()) + ")" - - -def results_query_part(entity): - """ Generates the results part of the query. The results contain - all the entity's fields as well as prefetched relationships. - - Note that this is a recursive function. If there is a cycle in the - prefetched relationship graph, this function will recurse infinitely. - - Args: - entity (type): The entity which needs fetching. - """ - # Query for fields - fields = [ - field.result_subquery - if field.result_subquery is not None else field.graphql_name - for field in entity.fields() - ] - - # Query for cached relationships - fields.extend([ - Query(rel.graphql_name, rel.destination_type).format()[0] - for rel in entity.relationships() - if rel.cache - ]) - return " ".join(fields) - - -class Query: - """ A data structure used during the construction of a query. Supports - subquery (also Query object) nesting for relationship. """ - - def __init__(self, - what, - subquery, - where=None, - paginate=False, - order_by=None): - """ Initializer. - Args: - what (str): What is being queried. Typically an object type in - singular or plural (i.e. "project" or "projects"). - subquery (Query or type): Either a Query object that is formatted - recursively or a Entity subtype in which case the standard - results (see `results_query_part`) are retrieved. - where (None, Comparison or LogicalExpression): the filtering clause. - paginate (bool): If the "%skip %first" pagination substring should - be added to the query. Used for collection pagination in combination - with PaginatedCollection. - order_by (tuple): A tuple consisting of (Field, Field.Order) indicating - how the query should sort the collection. - """ - self.what = what - self.subquery = subquery - self.paginate = paginate - self.where = where - self.order_by = order_by - - def format_subquery(self): - """ Formats the subquery (a Query or Entity subtype). """ - if isinstance(self.subquery, Query): - return self.subquery.format() - elif issubclass(self.subquery, Entity): - return results_query_part(self.subquery), {} - else: - raise MalformedQueryException() - - def format_clauses(self, params): - """ Formats the where, order_by and pagination clauses. - Args: - params (dict): The current parameter dictionary. 
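- Return:
- str, the formatted clauses wrapped in parentheses, or an empty
- string when no clause applies.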
-
- """
-
- def format_where(node):
- """ Helper that recursively constructs a where clause from a
- LogicalExpression tree (leaf nodes are Comparisons). """
- COMPARISON_TO_SUFFIX = {
- Comparison.Op.EQ: "",
- Comparison.Op.NE: "_not",
- Comparison.Op.LT: "_lt",
- Comparison.Op.GT: "_gt",
- Comparison.Op.LE: "_lte",
- Comparison.Op.GE: "_gte",
- }
- assert isinstance(node, (Comparison, LogicalExpression))
- if isinstance(node, Comparison):
- param_name = "param_%d" % len(params)
- params[param_name] = (node.value, node.field)
- return "{%s%s: $%s}" % (node.field.graphql_name,
- COMPARISON_TO_SUFFIX[node.op],
- param_name)
- if node.op == LogicalExpression.Op.NOT:
- return "{NOT: [%s]}" % format_where(node.first)
-
- return "{%s: [%s, %s]}" % (node.op.name.upper(),
- format_where(node.first),
- format_where(node.second))
-
- paginate = "skip: %d first: %d" if self.paginate else ""
-
- where = "where: %s" % format_where(self.where) if self.where else ""
-
- if self.order_by:
- order_by = "orderBy: %s_%s" % (self.order_by[0].graphql_name,
- self.order_by[1].name.upper())
- else:
- order_by = ""
-
- clauses = " ".join(filter(None, (where, paginate, order_by)))
- return "(" + clauses + ")" if clauses else ""
-
- def format(self):
- """ Formats the full query but without "query" prefix, query name
- and parameter declaration.
- Return:
- (str, dict) tuple. str is the query and dict maps parameter
- names to (value, field) tuples.
- """
- subquery, params = self.format_subquery()
- clauses = self.format_clauses(params)
- query = "%s%s{%s}" % (self.what, clauses, subquery)
- return query, params
-
- def format_top(self, name):
- """ Formats the full query including "query" prefix, query name
- and parameter declaration. The result of this function can be
- sent to the Client object for execution.
-
- Args:
- name (str): Query name, without the "PyApi" suffix, it's appended
- automatically by this method.
- Return:
- (str, dict) tuple. str is the full query and dict maps parameter
- names to parameter values.
- """
- query, params = self.format()
- param_declaration = format_param_declaration(params)
- query = "query %sPyApi%s{%s}" % (name, param_declaration, query)
- return query, {param: value for param, (value, _) in params.items()}
-
-
-def get_single(entity, uid):
- """ Constructs the query and params dict for obtaining a single object,
- either by ID or without params.
- Args:
- entity (type): An Entity subtype being obtained.
- uid (str): The ID of the sought object. It can be None, which is legal for
- DB types that have a default object being returned (User and
- Organization).
- """
- type_name = entity.type_name()
- where = entity.uid == uid if uid else None
- return Query(utils.camel_case(type_name), entity,
- where).format_top("Get" + type_name)
-
-
-def logical_ops(where):
- """ Returns a generator that yields all the logical operator
- type objects (`LogicalExpression.Op` instances) from a where
- clause.
-
- Args:
- where (LogicalExpression, Comparison or None): The where
- clause used for filtering in a query.
- Return:
- See above.
- """
- if isinstance(where, LogicalExpression):
- yield where.op
- for f in chain(logical_ops(where.first), logical_ops(where.second)):
- yield f
-
-
-def check_where_clause(entity, where):
- """ Checks the `where` clause of a query. A `where` clause is legal
- if it only refers to fields found in the entity it's defined for.
- Since only AND logical operations are supported server-side at the
- moment, logical OR and NOT are illegal. 
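-
- For example (illustrative; assumes comparisons compose with `&`/`|` as
- in labelbox.orm.comparison):
- >>> check_where_clause(Project, (Project.name == "a") & (Project.description == "b")) # OK
- >>> check_where_clause(Project, (Project.name == "a") | (Project.description == "b")) # raises InvalidQueryError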
- - Args: - entity (type): An Entity subclass type. - where (LogicalExpression or Comparison): The `where` clause of - query. - Return: - bool indicating if `where` is legal for `entity`. - """ - - def fields(where): - """ Yields all the fields in a `where` clause. """ - if isinstance(where, LogicalExpression): - for f in chain(fields(where.first), fields(where.second)): - yield f - elif isinstance(where, Comparison): - yield where.field - - # The `deleted` field is a special case, ignore it. - where_fields = [f for f in fields(where) if f != Entity.deleted] - invalid_fields = set(where_fields) - set(entity.fields()) - if invalid_fields: - raise InvalidAttributeError(entity, invalid_fields) - - if len(set(where_fields)) != len(where_fields): - raise InvalidQueryError( - "Where clause contains multiple comparisons for " - "the same field: %r." % where) - - if set(logical_ops(where)) not in (set(), {LogicalExpression.Op.AND}): - raise InvalidQueryError("Currently only AND logical ops are allowed in " - "the where clause of a query.") - - -def check_order_by_clause(entity, order_by): - """ Checks that the `order_by` clause field is a part of `entity`. - - Args: - entity (type): An Entity subclass type. - order_by ((field, ordering)): The ordering tuple consisting of - a field and sort ordering (ascending or descending). - Return: - bool indicating if `order_by` is legal for `entity`. - """ - if order_by is not None: - field, _ = order_by - if field not in entity.fields(): - raise InvalidAttributeError(entity, field) - - -def get_all(entity, where): - """ Constructs a query that fetches all items of the given type. The - resulting query is intended to be used for pagination, it contains - two python-string int-placeholders (%d) for 'skip' and 'first' - pagination parameters. - - Args: - entity (type): The object type being queried. - where (Comparison, LogicalExpression or None): The `where` clause - for filtering. - Return: - (str, dict) tuple that is the query string and parameters. - """ - check_where_clause(entity, where) - type_name = entity.type_name() - query = Query(utils.camel_case(type_name) + "s", entity, where, True) - return query.format_top("Get" + type_name + "s") - - -def relationship(source, relationship, where, order_by): - """ Constructs a query that fetches all items from a -to-many - relationship. To be used like: - >>> project = ... - >>> query_str, params = relationship(Project, "datasets", Dataset) - >>> datasets = PaginatedCollection( - client, query_str, params, ["project", "datasets"], - Dataset) - - The resulting query is intended to be used for pagination, it contains - two python-string int-placeholders (%d) for 'skip' and 'first' - pagination parameters. - - Args: - source (DbObject or type): If a `DbObject` then the source of the - relationship (the query originates from that particular object). - If `type`, then the source of the relationship is implicit, even - without the ID. Used for expanding from Organization. - relationship (Relationship): The relationship. - where (Comparison, LogicalExpression or None): The `where` clause - for filtering. - order_by (None or (Field, Field.Order): The `order_by` clause for - sorting results. - Return: - (str, dict) tuple that is the query string and parameters. 
- """ - check_where_clause(relationship.destination_type, where) - check_order_by_clause(relationship.destination_type, order_by) - to_many = relationship.relationship_type == Relationship.Type.ToMany - subquery = Query(relationship.graphql_name, relationship.destination_type, - where, to_many, order_by) - query_where = type(source).uid == source.uid if isinstance(source, Entity) \ - else None - query = Query(utils.camel_case(source.type_name()), subquery, query_where) - return query.format_top("Get" + source.type_name() + - utils.title_case(relationship.graphql_name)) - - -def create(entity, data): - """ Generates a query and parameters for creating a new DB object. - - Args: - entity (type): An Entity subtype indicating which kind of - DB object needs to be created. - data (dict): A dict that maps Fields and Relationships to values, new - object data. - Return: - (query_string, parameters) - """ - type_name = entity.type_name() - - def format_param_value(attribute, param): - if isinstance(attribute, Field): - return "%s: $%s" % (attribute.graphql_name, param) - else: - return "%s: {connect: {id: $%s}}" % (utils.camel_case( - attribute.graphql_name), param) - - # Convert data to params - params = { - field.graphql_name: (value, field) for field, value in data.items() - } - - query_str = """mutation Create%sPyApi%s{create%s(data: {%s}) {%s}} """ % ( - type_name, format_param_declaration(params), type_name, " ".join( - format_param_value(attribute, param) - for param, (_, attribute) in params.items()), - results_query_part(entity)) - - return query_str, {name: value for name, (value, _) in params.items()} - - -def update_relationship(a, b, relationship, update): - """ Updates the relationship in DB object `a` to connect or disconnect - DB object `b`. - - Args: - a (DbObject): The object being updated. - b (DbObject): Object on the other side of the relationship. - relationship (Relationship): The relationship from `a` to `b`. - update (str): The type of update. Must be either `connect` or - `disconnect`. - Return: - (query_string, query_parameters) - """ - to_one_disconnect = update == "disconnect" and \ - relationship.relationship_type == Relationship.Type.ToOne - - a_uid_param = utils.camel_case(type(a).type_name()) + "Id" - - if not to_one_disconnect: - b_uid_param = utils.camel_case(type(b).type_name()) + "Id" - param_declr = "($%s: ID!, $%s: ID!)" % (a_uid_param, b_uid_param) - b_query = "{id: $%s}" % b_uid_param - else: - param_declr = "($%s: ID!)" % a_uid_param - b_query = "true" - - query_str = """mutation %s%sAnd%sPyApi%s{update%s( - where: {id: $%s} data: {%s: {%s: %s}}) {id}} """ % ( - utils.title_case(update), type(a).type_name(), type(b).type_name(), - param_declr, utils.title_case(type(a).type_name()), a_uid_param, - relationship.graphql_name, update, b_query) - - if to_one_disconnect: - params = {a_uid_param: a.uid} - else: - params = {a_uid_param: a.uid, b_uid_param: b.uid} - - return query_str, params - - -def update_fields(db_object, values): - """ Creates a query that updates `db_object` fields with the - given values. - - Args: - db_object (DbObject): The DB object being updated. - values (dict): Maps Fields to new values. All Fields - must be legit fields in `db_object`. 
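-
- For example (hypothetical object and value):
- >>> query_str, params = update_fields(project, {type(project).name: "New name"})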
- Return: - (query_string, query_parameters) - """ - type_name = db_object.type_name() - id_param = "%sId" % type_name - values_str = " ".join("%s: $%s" % (field.graphql_name, field.graphql_name) - for field, _ in values.items()) - params = { - field.graphql_name: (value, field) for field, value in values.items() - } - params[id_param] = (db_object.uid, Entity.uid) - - query_str = """mutation update%sPyApi%s{update%s( - where: {id: $%s} data: {%s}) {%s}} """ % ( - utils.title_case(type_name), format_param_declaration(params), - type_name, id_param, values_str, results_query_part(type(db_object))) - - return query_str, {name: value for name, (value, _) in params.items()} - - -def delete(db_object): - """ Generates a query that deletes the given `db_object` from the DB. - - Args: - db_object (DbObject): The DB object being deleted. - """ - id_param = "%sId" % db_object.type_name() - query_str = """mutation delete%sPyApi%s{update%s( - where: {id: $%s} data: {deleted: true}) {id}} """ % ( - db_object.type_name(), "($%s: ID!)" % id_param, db_object.type_name(), - id_param) - - return query_str, {id_param: db_object.uid} - - -def bulk_delete(db_objects, use_where_clause): - """ Generates a query that bulk-deletes the given `db_objects` from the - DB. - - Args: - db_objects (list): A list of DB objects of the same type. - use_where_clause (bool): If the object IDs should be passed to the - mutation in a `where` clause or directly as a mutation value. - """ - type_name = db_objects[0].type_name() - if use_where_clause: - query_str = "mutation delete%ssPyApi{delete%ss(where: {%sIds: [%s]}){id}}" - else: - query_str = "mutation delete%ssPyApi{delete%ss(%sIds: [%s]){id}}" - query_str = query_str % ( - utils.title_case(type_name), utils.title_case(type_name), - utils.camel_case(type_name), ", ".join( - '"%s"' % db_object.uid for db_object in db_objects)) - return query_str, {} - - -def where_as_dict(entity, node: Comparison) -> Dict[str, Any]: - check_where_clause(entity, node) - """ Helper that constructs a where clause from a Comparison node. """ - COMPARISON_TO_SUFFIX = { - Comparison.Op.EQ: "", - Comparison.Op.NE: "_not", - Comparison.Op.LT: "_lt", - Comparison.Op.GT: "_gt", - Comparison.Op.LE: "_lte", - Comparison.Op.GE: "_gte", - } - - key = f"{node.field.graphql_name}{COMPARISON_TO_SUFFIX[node.op]}" - return {key: node.value} - - -def order_by_as_string(entity, node: tuple) -> str: - check_order_by_clause(entity, node) - return f"{node[0].graphql_name}_{node[1].name.upper()}" ----- -labelbox/orm/__init__.py - ----- -labelbox/orm/model.py -from enum import Enum, auto -from typing import Dict, List, Union, Any, Type, TYPE_CHECKING - -import labelbox -from labelbox import utils -from labelbox.exceptions import InvalidAttributeError -from labelbox.orm.comparison import Comparison -""" Defines Field, Relationship and Entity. These classes are building -blocks for defining the Labelbox schema, DB object operations and -queries. """ - - -class Field: - """ Represents a field in a database table. A Field has a name, a type - (corresponds to server-side GraphQL type) and a server-side name. The - server-side name is most often just a camelCase version of the client-side - snake_case name. - - Supports comparison operators which return a `labelbox.comparison.Comparison` - object. 
For example: - >>> class Project: - >>> name = Field.String("name") - >>> - >>> comparison = Project.name == "MyProject" - - These `Comparison` objects can then be used for filtering: - >>> project = client.get_projects(comparison) - - Also exposes the ordering property used for sorting: - >>> labels = project.labels(order_by=Label.label.asc) - - Attributes: - field_type (Field.Type): The type of the field. - name (str): name that the attribute has in client-side Python objects - graphql_name (str): name that the attribute has in queries (and in - server-side database definition). - result_subquery (str): graphql query result payload for a field. - """ - - class Type(Enum): - Int = auto() - Float = auto() - String = auto() - Boolean = auto() - ID = auto() - DateTime = auto() - Json = auto() - - class EnumType: - - def __init__(self, enum_cls: type): - self.enum_cls = enum_cls - - @property - def name(self): - return self.enum_cls.__name__ - - class ListType: - """ Represents Field that is a list of some object. - Args: - list_cls (type): Type of object that list is made of. - graphql_type (str): Inner object's graphql type. - By default, the list_cls's name is used as the graphql type. - """ - - def __init__(self, list_cls: type, graphql_type=None): - self.list_cls = list_cls - if graphql_type is None: - self.graphql_type = self.list_cls.__name__ - else: - self.graphql_type = graphql_type - - @property - def name(self): - return f"[{self.graphql_type}]" - - class Order(Enum): - """ Type of sort ordering. """ - Asc = auto() - Desc = auto() - - @staticmethod - def Int(*args): - return Field(Field.Type.Int, *args) - - @staticmethod - def Float(*args): - return Field(Field.Type.Float, *args) - - @staticmethod - def String(*args): - return Field(Field.Type.String, *args) - - @staticmethod - def Boolean(*args): - return Field(Field.Type.Boolean, *args) - - @staticmethod - def ID(*args): - return Field(Field.Type.ID, *args) - - @staticmethod - def DateTime(*args): - return Field(Field.Type.DateTime, *args) - - @staticmethod - def Enum(enum_cls: type, *args): - return Field(Field.EnumType(enum_cls), *args) - - @staticmethod - def Json(*args): - return Field(Field.Type.Json, *args) - - @staticmethod - def List(list_cls: type, graphql_type=None, **kwargs): - return Field(Field.ListType(list_cls, graphql_type), **kwargs) - - def __init__(self, - field_type: Union[Type, EnumType, ListType], - name, - graphql_name=None, - result_subquery=None): - """ Field init. - Args: - field_type (Field.Type): The type of the field. - name (str): client-side Python attribute name of a database - object. - graphql_name (str): query and server-side name of a database object. - If None, it is constructed from the client-side name by converting - snake_case (Python convention) into camelCase (GraphQL convention). - result_subquery (str): graphql query result payload for a field. - """ - self.field_type = field_type - self.name = name - if graphql_name is None: - graphql_name = utils.camel_case(name) - self.graphql_name = graphql_name - self.result_subquery = result_subquery - - @property - def asc(self): - """ Property that resolves to tuple (Field, Field.Order). - Used for easy definition of sort ordering: - >>> projects_ordered = client.get_projects(order_by=Project.name.asc) - """ - return (self, Field.Order.Asc) - - @property - def desc(self): - """ Property that resolves to tuple (Field, Field.Order). 
-        Used for easy definition of sort ordering:
-        >>> projects_ordered = client.get_projects(order_by=Project.name.desc)
-        """
-        return (self, Field.Order.Desc)
-
-    def __eq__(self, other):
-        """ Equality of Fields has two meanings. If comparing to a Field object,
-        then a boolean indicator if the fields are identical is returned. If
-        comparing to any other type, a Comparison object is created.
-        """
-        if isinstance(other, Field):
-            return self is other
-
-        return Comparison.Op.EQ(self, other)
-
-    def __ne__(self, other):
-        """ Inequality of Fields has two meanings. If comparing to a Field
-        object, then a boolean indicator if the fields are distinct is
-        returned. If comparing to any other type, a Comparison object is created.
-        """
-        if isinstance(other, Field):
-            return self is not other
-
-        return Comparison.Op.NE(self, other)
-
-    def __hash__(self):
-        # Hash is implemented as ID, because for each DB field exactly one
-        # Field object should exist in the Python API.
-        return id(self)
-
-    def __lt__(self, other):
-        return Comparison.Op.LT(self, other)
-
-    def __gt__(self, other):
-        return Comparison.Op.GT(self, other)
-
-    def __le__(self, other):
-        return Comparison.Op.LE(self, other)
-
-    def __ge__(self, other):
-        return Comparison.Op.GE(self, other)
-
-    def __str__(self):
-        return self.name
-
-    def __repr__(self):
-        return "<Field: %s>" % self.name
-
-
-class Relationship:
-    """ Represents a relationship in a database table.
-
-    Attributes:
-        relationship_type (Relationship.Type): Indicator if to-one or to-many
-        destination_type_name (str): Name of the Entity subtype that's on
-            the other side of the relationship. str is used instead of the
-            type object itself because that type might not be declared at
-            the point of a `Relationship` object initialization.
-        filter_deleted (bool): Indicator whether a `deleted=false` filtering
-            clause should be added to the query when fetching relationship
-            objects.
-        name (str): Name of the relationship in the snake_case format.
-        graphql_name (str): Name of the relationship server-side. Most often
-            (not always) just a camelCase version of `name`.
-        cache (bool): Whether or not to cache the relationship values.
-            Useful for objects that aren't directly queryable from the api
-            (relationship query builder won't work). Also useful for
-            expensive ToOne relationships.
-        deprecation_warning (str): Optional message to display when the
-            RelationshipManager is called.
-
-    """
-
-    class Type(Enum):
-        ToOne = auto()
-        ToMany = auto()
-
-    @staticmethod
-    def ToOne(*args, **kwargs):
-        return Relationship(Relationship.Type.ToOne, *args, **kwargs)
-
-    @staticmethod
-    def ToMany(*args, **kwargs):
-        return Relationship(Relationship.Type.ToMany, *args, **kwargs)
-
-    def __init__(self,
-                 relationship_type,
-                 destination_type_name,
-                 filter_deleted=True,
-                 name=None,
-                 graphql_name=None,
-                 cache=False,
-                 deprecation_warning=None):
-        self.relationship_type = relationship_type
-        self.destination_type_name = destination_type_name
-        self.filter_deleted = filter_deleted
-        self.cache = cache
-        self.deprecation_warning = deprecation_warning
-
-        if name is None:
-            name = utils.snake_case(destination_type_name) + (
-                "s" if relationship_type == Relationship.Type.ToMany else "")
-        self.name = name
-
-        if graphql_name is None:
-            graphql_name = utils.camel_case(name)
-        self.graphql_name = graphql_name
-
-    @property
-    def destination_type(self):
-        return getattr(Entity, self.destination_type_name)
-
-    def __str__(self):
-        return self.name
-
-    def __repr__(self):
-        return "<Relationship: %s>" % self.name
-
-
-class EntityMeta(type):
-    """ Entity metaclass. Registers Entity subclasses as attributes
-    of the Entity class object so they can be referenced for example like:
-    Entity.Project.
-    """
-    # Maps Entity name to Relationships for all currently defined Entities
-    relationship_mappings: Dict[str, List[Relationship]] = {}
-
-    def __setattr__(self, key: Any, value: Any):
-        super().__setattr__(key, value)
-
-    def __init__(cls, clsname, superclasses, attributedict):
-        super().__init__(clsname, superclasses, attributedict)
-        cls.validate_cached_relationships()
-        if clsname != "Entity":
-            setattr(Entity, clsname, cls)
-            EntityMeta.relationship_mappings[utils.snake_case(
-                cls.__name__)] = cls.relationships()
-
-    @staticmethod
-    def raise_for_nested_cache(first: str, middle: str, last: List[str]):
-        raise TypeError(
-            "Cannot cache a relationship to an Entity with its own cached relationship(s). "
-            f"`{first}` caches `{middle}` which caches `{last}`")
-
-    @staticmethod
-    def cached_entities(entity_name: str):
-        """
-        Return all cached entities for a given Entity name
-        """
-        cached_entities = EntityMeta.relationship_mappings.get(entity_name, [])
-        return {
-            entity.name: entity for entity in cached_entities if entity.cache
-        }
-
-    def validate_cached_relationships(cls):
-        """
-        Graphql doesn't allow for infinite nesting in queries.
-        This function checks that cached relationships result in valid queries.
-        * It does this by making sure that a cached relationship does not
-          reference any entity with its own cached relationships.
-
-        This check is performed by looking to see if this entity caches
-        any entities that have their own cached fields. If this entity
-        that we are checking has any cached fields then we also check
-        all currently defined entities to see if they cache this entity.
-
-        A two-way check is necessary because checks are performed as classes
-        are being defined, as opposed to after all objects have been created.
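-
-        For example (an illustrative sketch): if entity `A` caches a
-        relationship to `B` and `B` itself caches one to `C`, whichever of
-        the two classes is defined second triggers the TypeError from
-        `raise_for_nested_cache`.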
- """ - # All cached relationships - cached_rels = [r for r in cls.relationships() if r.cache] - - # Check if any cached entities have their own cached fields - for rel in cached_rels: - nested = cls.cached_entities(rel.name) - if nested: - cls.raise_for_nested_cache(utils.snake_case(cls.__name__), - rel.name, list(nested.keys())) - - # If the current Entity (cls) has any cached relationships (cached_rels) - # then no other defined Entity (entities in EntityMeta.relationship_mappings) can cache this Entity. - if cached_rels: - # For all currently defined Entities - for entity_name in EntityMeta.relationship_mappings: - # Get all cached ToOne relationships - rels = cls.cached_entities(entity_name) - # Check if the current Entity (cls) is referenced by the Entity with `entity_name` - rel = rels.get(utils.snake_case(cls.__name__)) - # If rel exists and is cached then raise an exception - # This means `entity_name` caches `cls` which cached items in `cached_rels` - if rel and rel.cache: - cls.raise_for_nested_cache( - utils.snake_case(entity_name), - utils.snake_case(cls.__name__), - [entity.name for entity in cached_rels]) - - -class Entity(metaclass=EntityMeta): - """ An entity that contains fields and relationships. Base class - for DbObject (which is base class for concrete schema classes). """ - - # Every Entity has an "id" and a "deleted" field - # Name the "id" field "uid" in Python to avoid conflict with keyword. - uid = Field.ID("uid", "id") - - # Some Labelbox objects have a "deleted" attribute for soft deletions. - # It's declared in Entity so it can be filtered out in class methods - # suchs as `fields()`. - deleted = Field.Boolean("deleted") - - if TYPE_CHECKING: - DataRow: Type[labelbox.DataRow] - Webhook: Type[labelbox.Webhook] - Task: Type[labelbox.Task] - AssetAttachment: Type[labelbox.AssetAttachment] - ModelRun: Type[labelbox.ModelRun] - Review: Type[labelbox.Review] - User: Type[labelbox.User] - LabelingFrontend: Type[labelbox.LabelingFrontend] - BulkImportRequest: Type[labelbox.BulkImportRequest] - Benchmark: Type[labelbox.Benchmark] - IAMIntegration: Type[labelbox.IAMIntegration] - LabelingFrontendOptions: Type[labelbox.LabelingFrontendOptions] - Label: Type[labelbox.Label] - MEAPredictionImport: Type[labelbox.MEAPredictionImport] - MALPredictionImport: Type[labelbox.MALPredictionImport] - Invite: Type[labelbox.Invite] - InviteLimit: Type[labelbox.InviteLimit] - ProjectRole: Type[labelbox.ProjectRole] - Project: Type[labelbox.Project] - Batch: Type[labelbox.Batch] - CatalogSlice: Type[labelbox.CatalogSlice] - ModelSlice: Type[labelbox.ModelSlice] - TaskQueue: Type[labelbox.TaskQueue] - - @classmethod - def _attributes_of_type(cls, attr_type): - """ Yields all the attributes in `cls` of the given `attr_type`. """ - for attr_name in dir(cls): - attr = getattr(cls, attr_name) - if isinstance(attr, attr_type): - yield attr - - @classmethod - def fields(cls): - """ Returns a generator that yields all the Fields declared in a - concrete subclass. - """ - for attr in cls._attributes_of_type(Field): - if attr != Entity.deleted: - yield attr - - @classmethod - def relationships(cls): - """ Returns a generator that yields all the Relationships declared in - a concrete subclass. - """ - return cls._attributes_of_type(Relationship) - - @classmethod - def field(cls, field_name): - """ Returns a Field object for the given name. - Args: - field_name (str): Field name, Python (snake-case) convention. 
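-
-            For example (a sketch, assuming the Project entity is defined):
-            >>> name_field = Entity.Project.field("name")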
- Return: - Field object - Raises: - InvalidAttributeError: in case this DB object type does not - contain a field with the given name. - """ - field_obj = getattr(cls, field_name, None) - if not isinstance(field_obj, Field): - raise InvalidAttributeError(cls, field_name) - return field_obj - - @classmethod - def attribute(cls, attribute_name): - """ Returns a Field or a Relationship object for the given name. - Args: - attribute_name (str): Field or Relationship name, Python - (snake-case) convention. - Return: - Field or Relationship object - Raises: - InvalidAttributeError: in case this DB object type does not - contain an attribute with the given name. - """ - attribute_object = getattr(cls, attribute_name, None) - if not isinstance(attribute_object, (Field, Relationship)): - raise InvalidAttributeError(cls, attribute_name) - return attribute_object - - @classmethod - def type_name(cls): - """ Returns this DB object type name in TitleCase. For example: - Project, DataRow, ... - """ - return cls.__name__.split(".")[-1] - ----- -labelbox/orm/comparison.py -from enum import Enum, auto -""" Classes for defining the client-side comparison operations used -for filtering data in fetches. Intended for use by library internals -and not by the end user. -""" - - -class LogicalExpressionComponent: - """ Implements bitwise logical operator methods (&, | and ~) so they - return a LogicalExpression object containing this - LogicalExpressionComponent. - """ - - def __and__(self, other): - if not isinstance(other, (LogicalExpression, Comparison)): - return NotImplemented - return LogicalExpression.Op.AND(self, other) - - def __or__(self, other): - if not isinstance(other, (LogicalExpression, Comparison)): - return NotImplemented - return LogicalExpression.Op.OR(self, other) - - def __invert__(self): - return LogicalExpression.Op.NOT(self) - - -class LogicalExpression(LogicalExpressionComponent): - """ A unary (NOT) or binary (AND, OR) logical expression between - Comparison or LogicalExpression objects. """ - - class Op(Enum): - """ Type of logical operation. """ - AND = auto() - OR = auto() - NOT = auto() - - def __call__(self, first, second=None): - """ Forwards to LogicalExpression constructor, passing `self` - as the `op` argument. """ - return LogicalExpression(self, first, second) - - def __init__(self, op, first, second=None): - """ LogicalExpression constructor. - - Args: - op (LogicalExpression.Op): The type of logical operation. - first (LogicalExpression or Comparison): First operand. - second (LogicalExpression or Comparison): Second operand. - """ - self.op = op - self.first = first - self.second = second - - def __eq__(self, other): - return self.op == other.op and ( - (self.first == other.first and self.second == other.second) or - (self.first == other.second and self.second == other.first)) - - def __hash__(self): - return hash( - self.op) + 2833 * hash(self.first) + 2837 * hash(self.second) - - def __repr__(self): - return "%r %s %r" % (self.first, self.op.name, self.second) - - def __str__(self): - return "%s %s %s" % (self.first, self.op.name, self.second) - - -class Comparison(LogicalExpressionComponent): - """ A comparison between a database value (represented by a - `labelbox.schema.Field` object) and a constant value. """ - - class Op(Enum): - """ Type of the comparison operation. """ - EQ = auto() - NE = auto() - LT = auto() - GT = auto() - LE = auto() - GE = auto() - - def __call__(self, *args): - """ Forwards to Comparison constructor, passing `self` - as the `op` argument. 
""" - return Comparison(self, *args) - - def __init__(self, op, field, value): - """ Comparison constructor. - - Args: - op (Comparison.Op): The type of comparison. - field (labelbox.schema.Field): Field being compared. - value (any): Value to which the DB field is compared. - """ - self.op = op - self.field = field - self.value = value - - def __eq__(self, other): - return self.op == other.op and \ - self.field == other.field and self.value == other.value - - def __hash__(self): - return hash(self.op) + 2861 * hash(self.field) + 2927 * hash(self.value) - - def __repr__(self): - return "%r %s %r" % (self.field, self.op.name, self.value) - - def __str__(self): - return "%s %s %s" % (self.field, self.op.name, self.value) - ----- -labelbox/orm/db_object.py -from datetime import datetime, timezone -from functools import wraps -import logging -import json - -from labelbox import utils -from labelbox.exceptions import InvalidQueryError, InvalidAttributeError -from labelbox.orm import query -from labelbox.orm.model import Field, Relationship, Entity -from labelbox.pagination import PaginatedCollection - -logger = logging.getLogger(__name__) - - -class DbObject(Entity): - """ A client-side representation of a database object (row). Intended as - base class for classes representing concrete database types (for example - a Project). Exposes support functionalities so that the concrete subclass - definition be as simple and DRY as possible. It should come down to just - listing Fields of that particular database type. For example: - - >>> class Project(DbObject): - >>> name = Field.String("name") - >>> labels = Relationship.ToMany("Label", True) - - This defines a `Project` class that has class attributes which are - `Field`s and `Relationship`s. An instance of `Project` represents - a database record. It has the same attributes as the `Project` class, - but they are now attribute values of that record: - - >>> project = client.create_project(name="MyProject") - >>> project.name - "MyProject" - """ - - def __init__(self, client, field_values): - """ Constructor of a database object. Generally it should only be used - by library internals and not by the end user. - - Args: - client (labelbox.Client): the client used for fetching data from DB. - field_values (dict): Data obtained from the DB. Maps database object - fields (their graphql_name version) to values. - """ - self.client = client - self._set_field_values(field_values) - for relationship in self.relationships(): - value = field_values.get(utils.camel_case(relationship.name)) - if relationship.cache and value is None: - raise KeyError( - f"Expected field values for {relationship.name}") - setattr(self, relationship.name, - RelationshipManager(self, relationship, value)) - - def _set_field_values(self, field_values): - """ Sets field values on this object. Ensures proper value conversions. - Args: - field_values (dict): Maps field names (GraphQL variant, snakeCase) - to values. *Must* contain all field values for this object's - DB type. 
- """ - for field in self.fields(): - value = field_values[field.graphql_name] - if field.field_type == Field.Type.DateTime and value is not None: - try: - value = datetime.strptime(value, "%Y-%m-%dT%H:%M:%S.%fZ") - value = value.replace(tzinfo=timezone.utc) - except ValueError: - logger.warning( - "Failed to convert value '%s' to datetime for " - "field %s", value, field) - elif isinstance(field.field_type, Field.EnumType): - value = field.field_type.enum_cls(value) - elif isinstance(field.field_type, Field.ListType): - if field.field_type.list_cls.__name__ == "DataRowMetadataField": - mdo = self.client.get_data_row_metadata_ontology() - try: - value = mdo.parse_metadata_fields(value) - except ValueError: - logger.warning( - "Failed to convert value '%s' to metadata for field %s", - value, field) - setattr(self, field.name, value) - - def __repr__(self): - type_name = self.type_name() - if "uid" in self.__dict__: - return "<%s ID: %s>" % (type_name, self.uid) - else: - return "<%s>" % type_name - - def __str__(self): - attribute_values = { - field.name: getattr(self, field.name) for field in self.fields() - } - return "<%s %s>" % (self.type_name().split(".")[-1], - json.dumps(attribute_values, indent=4, default=str)) - - def __eq__(self, other): - return (isinstance(other, DbObject) and - self.type_name() == other.type_name() and self.uid == other.uid) - - def __hash__(self): - return 7541 * hash(self.type_name()) + hash(self.uid) - - -class RelationshipManager: - """ Manages relationships (object fetching and updates) for a `DbObject` - instance. There is one RelationshipManager for each relationship in - each `DbObject` instance. - """ - - def __init__(self, source, relationship, value=None): - """Args: - source (DbObject subclass instance): The object that's the source - of the relationship. - relationship (labelbox.schema.Relationship): The relationship - schema descriptor object. - """ - self.source = source - self.relationship = relationship - self.supports_filtering = True - self.supports_sorting = True - self.filter_on_id = True - self.value = value - - def __call__(self, *args, **kwargs): - """ Forwards the call to either `_to_many` or `_to_one` methods, - depending on relationship type. """ - - if self.relationship.deprecation_warning: - logger.warning(self.relationship.deprecation_warning) - - if self.relationship.relationship_type == Relationship.Type.ToMany: - return self._to_many(*args, **kwargs) - else: - return self._to_one(*args, **kwargs) - - def _to_many(self, where=None, order_by=None): - """ Returns an iterable over the destination relationship objects. - Args: - where (None, Comparison or LogicalExpression): Filtering clause. - order_by (None or (Field, Field.Order)): Ordering clause. - Return: - iterable over destination DbObject instances. 
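-
-        For example (an illustrative sketch, assuming a Project with a
-        to-many `labels` relationship as defined in this schema):
-        >>> labels = project.labels(order_by=Entity.Label.uid.asc)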
- """ - rel = self.relationship - if where is not None and not self.supports_filtering: - raise InvalidQueryError( - "Relationship %s.%s doesn't support filtering" % - (self.source.type_name(), rel.name)) - if order_by is not None and not self.supports_sorting: - raise InvalidQueryError( - "Relationship %s.%s doesn't support sorting" % - (self.source.type_name(), rel.name)) - - if rel.filter_deleted: - not_deleted = rel.destination_type.deleted == False - where = not_deleted if where is None else where & not_deleted - - query_string, params = query.relationship( - self.source if self.filter_on_id else type(self.source), rel, where, - order_by) - return PaginatedCollection( - self.source.client, query_string, params, - [utils.camel_case(self.source.type_name()), rel.graphql_name], - rel.destination_type) - - def _to_one(self): - """ Returns the relationship destination object. """ - rel = self.relationship - - if self.value: - return rel.destination_type(self.source.client, self.value) - - query_string, params = query.relationship(self.source, rel, None, None) - result = self.source.client.execute(query_string, params) - result = result and result.get( - utils.camel_case(type(self.source).type_name())) - result = result and result.get(rel.graphql_name) - if result is None: - return None - - return rel.destination_type(self.source.client, result) - - def connect(self, other): - """ Connects source object of this manager to the `other` object. """ - query_string, params = query.update_relationship( - self.source, other, self.relationship, "connect") - self.source.client.execute(query_string, params) - - def disconnect(self, other): - """ Disconnects source object of this manager from the `other` object. """ - query_string, params = query.update_relationship( - self.source, other, self.relationship, "disconnect") - self.source.client.execute(query_string, params) - - -class Updateable: - - def update(self, **kwargs): - """ Updates this DB object with new values. Values should be - passed as key-value arguments with field names as keys: - >>> db_object.update(name="New name", title="A title") - - Kwargs: - Key-value arguments defining which fields should be updated - for which values. Keys must be field names in this DB object's - type. - Raise: - InvalidAttributeError: if there exists a key in `kwargs` - that's not a field in this object type. - """ - values = {self.field(name): value for name, value in kwargs.items()} - invalid_fields = set(values) - set(self.fields()) - if invalid_fields: - raise InvalidAttributeError(type(self), invalid_fields) - - query_string, params = query.update_fields(self, values) - res = self.client.execute(query_string, params) - res = res["update%s" % utils.title_case(self.type_name())] - self._set_field_values(res) - - -class Deletable: - """ Implements deletion for objects that have a `deleted` attribute. """ - - def delete(self): - """ Deletes this DB object from the DB (server side). After - a call to this you should not use this DB object anymore. - """ - query_string, params = query.delete(self) - self.client.execute(query_string, params) - - -class BulkDeletable: - """ Implements deletion for objects that have a custom, bulk deletion - mutation (accepts a list of IDs of objects to be deleted). - - A subclass must override the `bulk_delete` static method so it - accepts only the `objects` argument and then invoke BulkDeletable.bulk_delete - with the appropriate `use_where_clause` argument for that particular - type. 
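-
-    A sketch of such an override (illustrative only):
-    >>> @staticmethod
-    >>> def bulk_delete(objects):
-    >>>     BulkDeletable._bulk_delete(objects, use_where_clause=True)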
- """ - - @staticmethod - def _bulk_delete(objects, use_where_clause): - """ - Args: - objects (list): Objects to delete. All objects must be of the same - DbObject subtype. - use_where_clause (bool): If the GraphQL query object IDs should be - passed under `where` or directly. Necessary because the bulk - deletion mutation is implemented differently for different - object types (DataRow.bulkDelete vs Label.bulkDelete). - """ - types = {type(o) for o in objects} - if len(types) != 1: - raise InvalidQueryError( - "Can't bulk-delete objects of different types: %r" % types) - - query_str, params = query.bulk_delete(objects, use_where_clause) - objects[0].client.execute(query_str, params) - - def delete(self): - """ Deletes this DB object from the DB (server side). After - a call to this you should not use this DB object anymore. - """ - type(self).bulk_delete([self]) - - -def experimental(fn): - """Decorator used to mark functions that are experimental. This means that - the interface could change. This decorator will check if the client has - experimental features enabled. If not, it will raise a runtime error. - """ - is_static = isinstance(fn, staticmethod) - - @wraps(fn) - def wrapper(*args, **kwargs): - client = None - wrapped_fn = None - if is_static: - # assumes that the first argument is the client, needs modification if that changes - if len(args) >= 1: - client = args[0] - elif "client" in kwargs: - client = kwargs["client"] - else: - raise ValueError( - f"Static method {fn.__name__} must have a client passed in as the first " - f"argument or as a keyword argument.") - wrapped_fn = fn.__func__ - else: - client = args[0].client - wrapped_fn = fn - - if not client.enable_experimental: - raise RuntimeError( - f"This function {fn.__name__} relies on a experimental feature in the api. " - f"This means that the interface could change. 
" - f"Set `enable_experimental=True` in the client to enable use of " - f"experimental functions.") - return wrapped_fn(*args, **kwargs) - - return wrapper - ----- -labelbox/data/mixins.py -from typing import Optional, List - -from labelbox import pydantic_compat - -from labelbox.exceptions import ConfidenceNotSupportedException, CustomMetricsNotSupportedException - - -class ConfidenceMixin(pydantic_compat.BaseModel): - confidence: Optional[float] = None - - @pydantic_compat.validator("confidence") - def confidence_valid_float(cls, value): - if value is None: - return value - if not isinstance(value, (int, float)) or not 0 <= value <= 1: - raise ValueError("must be a number within [0,1] range") - return value - - def dict(self, *args, **kwargs): - res = super().dict(*args, **kwargs) - if "confidence" in res and res["confidence"] is None: - res.pop("confidence") - return res - - -class ConfidenceNotSupportedMixin: - - def __new__(cls, *args, **kwargs): - if "confidence" in kwargs: - raise ConfidenceNotSupportedException( - "Confidence is not supported for this annotation type yet") - return super().__new__(cls) - - -class CustomMetric(pydantic_compat.BaseModel): - name: str - value: float - - @pydantic_compat.validator("name") - def confidence_valid_float(cls, value): - if not isinstance(value, str): - raise ValueError("Name must be a string") - return value - - @pydantic_compat.validator("value") - def value_valid_float(cls, value): - if not isinstance(value, (int, float)): - raise ValueError("Value must be a number") - return value - - -class CustomMetricsMixin(pydantic_compat.BaseModel): - custom_metrics: Optional[List[CustomMetric]] = None - - def dict(self, *args, **kwargs): - res = super().dict(*args, **kwargs) - - if "customMetrics" in res and res["customMetrics"] is None: - res.pop("customMetrics") - - if "custom_metrics" in res and res["custom_metrics"] is None: - res.pop("custom_metrics") - - return res - - -class CustomMetricsNotSupportedMixin: - - def __new__(cls, *args, **kwargs): - if "custom_metrics" in kwargs: - raise CustomMetricsNotSupportedException( - "Custom metrics is not supported for this annotation type yet") - return super().__new__(cls) - ----- -labelbox/data/__init__.py - ----- -labelbox/data/generator.py -import logging -import threading -from queue import Queue -from typing import Any, Iterable -import threading - -logger = logging.getLogger(__name__) - - -class ThreadSafeGen: - """ - Wraps generators to make them thread safe - """ - - def __init__(self, iterable: Iterable[Any]): - """ - - """ - self.iterable = iterable - self.lock = threading.Lock() - - def __iter__(self): - return self - - def __next__(self): - with self.lock: - return next(self.iterable) - - -class PrefetchGenerator: - """ - Applys functions asynchronously to the output of a generator. - Useful for modifying the generator results based on data from a network - """ - - def __init__(self, data: Iterable[Any], prefetch_limit=20, num_executors=1): - if isinstance(data, (list, tuple)): - self._data = (r for r in data) - else: - self._data = data - - self.queue = Queue(prefetch_limit) - self.completed_threads = 0 - # Can only iterate over once it the queue.get hangs forever. 
- self.multithread = num_executors > 1 - self.done = False - - if self.multithread: - self._data = ThreadSafeGen(self._data) - self.num_executors = num_executors - self.threads = [ - threading.Thread(target=self.fill_queue) - for _ in range(num_executors) - ] - for thread in self.threads: - thread.daemon = True - thread.start() - else: - self._data = iter(self._data) - - def _process(self, value) -> Any: - raise NotImplementedError("Abstract method needs to be implemented") - - def fill_queue(self): - try: - for value in self._data: - value = self._process(value) - if value is None: - raise ValueError("Unexpected None") - self.queue.put(value) - except Exception as e: - self.queue.put( - ValueError(f"Unexpected exception while filling queue: {e}")) - finally: - self.queue.put(None) - - def __iter__(self): - return self - - def __next__(self) -> Any: - if self.done: - raise StopIteration - - if self.multithread: - value = self.queue.get() - - while value is None: - self.completed_threads += 1 - if self.completed_threads == self.num_executors: - self.done = True - for thread in self.threads: - thread.join() - raise StopIteration - value = self.queue.get() - if isinstance(value, Exception): - raise value - else: - value = self._process(next(self._data)) - return value - ----- -labelbox/data/ontology.py -from typing import Dict, List, Tuple, Union - -from labelbox.schema import ontology -from .annotation_types import (Text, Dropdown, Checklist, Radio, - ClassificationAnnotation, ObjectAnnotation, Mask, - Point, Line, Polygon, Rectangle, TextEntity) - - -def get_feature_schema_lookup( - ontology_builder: ontology.OntologyBuilder -) -> Tuple[Dict[str, str], Dict[str, str]]: - tool_lookup = {} - classification_lookup = {} - - def flatten_classification(classifications): - for classification in classifications: - if classification.feature_schema_id is None: - raise ValueError( - f"feature_schema_id cannot be None for classification `{classification.name}`." - ) - if isinstance(classification, ontology.Classification): - classification_lookup[ - classification.name] = classification.feature_schema_id - elif isinstance(classification, ontology.Option): - classification_lookup[ - classification.value] = classification.feature_schema_id - else: - raise TypeError( - f"Unexpected type found in ontology. `{type(classification)}`" - ) - flatten_classification(classification.options) - - for tool in ontology_builder.tools: - if tool.feature_schema_id is None: - raise ValueError( - f"feature_schema_id cannot be None for tool `{tool.name}`.") - tool_lookup[tool.name] = tool.feature_schema_id - flatten_classification(tool.classifications) - flatten_classification(ontology_builder.classifications) - return tool_lookup, classification_lookup - - -def _get_options(annotation: ClassificationAnnotation, - existing_options: List[ontology.Option]): - if isinstance(annotation.value, Radio): - answers = [annotation.value.answer] - elif isinstance(annotation.value, Text): - return existing_options - elif isinstance(annotation.value, (Checklist, Dropdown)): - answers = annotation.value.answer - else: - raise TypeError( - f"Expected one of Radio, Text, Checklist, Dropdown. 
Found {type(annotation.value)}" - ) - - option_names = {option.value for option in existing_options} - for answer in answers: - if answer.name not in option_names: - existing_options.append(ontology.Option(value=answer.name)) - option_names.add(answer.name) - return existing_options - - -def get_classifications( - annotations: List[ClassificationAnnotation], - existing_classifications: List[ontology.Classification] -) -> List[ontology.Classification]: - existing_classifications = { - classification.name: classification - for classification in existing_classifications - } - for annotation in annotations: - # If the classification exists then we just want to add options to it - classification_feature = existing_classifications.get(annotation.name) - if classification_feature: - classification_feature.options = _get_options( - annotation, classification_feature.options) - elif annotation.name not in existing_classifications: - existing_classifications[annotation.name] = ontology.Classification( - class_type=classification_mapping(annotation), - name=annotation.name, - options=_get_options(annotation, [])) - return list(existing_classifications.values()) - - -def get_tools( - annotations: List[ObjectAnnotation], - existing_tools: List[ontology.Classification]) -> List[ontology.Tool]: - existing_tools = {tool.name: tool for tool in existing_tools} - for annotation in annotations: - if annotation.name in existing_tools: - # We just want to update classifications - existing_tools[ - annotation.name].classifications = get_classifications( - annotation.classifications, - existing_tools[annotation.name].classifications) - else: - existing_tools[annotation.name] = ontology.Tool( - tool=tool_mapping(annotation), - name=annotation.name, - classifications=get_classifications(annotation.classifications, - [])) - return list(existing_tools.values()) - - -def tool_mapping( - annotation) -> Union[Mask, Polygon, Point, Rectangle, Line, TextEntity]: - tool_types = ontology.Tool.Type - mapping = { - Mask: tool_types.SEGMENTATION, - Polygon: tool_types.POLYGON, - Point: tool_types.POINT, - Rectangle: tool_types.BBOX, - Line: tool_types.LINE, - TextEntity: tool_types.NER, - } - result = mapping.get(type(annotation.value)) - if result is None: - raise TypeError( - f"Unexpected type found. {type(annotation.value)}. Expected one of {list(mapping.keys())}" - ) - return result - - -def classification_mapping( - annotation) -> Union[Text, Checklist, Radio, Dropdown]: - classification_types = ontology.Classification.Type - mapping = { - Text: classification_types.TEXT, - Checklist: classification_types.CHECKLIST, - Radio: classification_types.RADIO, - Dropdown: classification_types.DROPDOWN - } - result = mapping.get(type(annotation.value)) - if result is None: - raise TypeError( - f"Unexpected type found. {type(annotation.value)}. 
Expected one of {list(mapping.keys())}"
-        )
-    return result
-
-----
-labelbox/data/metrics/__init__.py
-from .confusion_matrix import confusion_matrix_metric, feature_confusion_matrix_metric
-from .iou import miou_metric, feature_miou_metric
-
-----
-labelbox/data/metrics/group.py
-"""
-Tools for grouping features and labels so that we can compute metrics on the individual groups
-"""
-from collections import defaultdict
-from typing import Dict, List, Tuple, Union
-
-from labelbox.data.annotation_types.annotation import ClassificationAnnotation
-from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnswer, Radio, Text
-try:
-    from typing import Literal
-except ImportError:
-    from typing_extensions import Literal
-
-from ..annotation_types.feature import FeatureSchema
-from ..annotation_types import ObjectAnnotation, ClassificationAnnotation, Label
-
-
-def get_identifying_key(
-    features_a: List[FeatureSchema], features_b: List[FeatureSchema]
-) -> Union[Literal['name'], Literal['feature_schema_id']]:
-    """
-    Checks to make sure that features in both sets contain the same type of identifying keys.
-    This can either be the feature name or feature schema id.
-
-    Args:
-        features_a : List of FeatureSchemas (usually ObjectAnnotations or ClassificationAnnotations)
-        features_b : List of FeatureSchemas (usually ObjectAnnotations or ClassificationAnnotations)
-    Returns:
-        The field name that is present in both feature lists.
-    """
-
-    all_schema_ids_defined_pred, all_names_defined_pred = all_have_key(
-        features_a)
-    if (not all_schema_ids_defined_pred and not all_names_defined_pred):
-        raise ValueError("All data must have feature_schema_ids or names set")
-
-    all_schema_ids_defined_gt, all_names_defined_gt = all_have_key(features_b)
-
-    # Prefer the name because the user will be able to know what it means;
-    # fall back to the schema id in case the name doesn't exist.
-    if (all_names_defined_pred and all_names_defined_gt):
-        return 'name'
-    elif all_schema_ids_defined_pred and all_schema_ids_defined_gt:
-        return 'feature_schema_id'
-    else:
-        raise ValueError(
-            "Ground truth and prediction annotations must have set all name or feature ids. "
-            "Otherwise there is no key to match on. Please update.")
-
-
-def all_have_key(features: List[FeatureSchema]) -> Tuple[bool, bool]:
-    """
-    Checks to make sure that all FeatureSchemas have names set or feature_schema_ids set.
-
-    Args:
-        features (List[FeatureSchema]): The features to check.
-
-    """
-    all_names = True
-    all_schemas = True
-    for feature in features:
-        if isinstance(feature, ClassificationAnnotation):
-            if isinstance(feature.value, Checklist):
-                all_schemas, all_names = all_have_key(feature.value.answer)
-            elif isinstance(feature.value, Text):
-                if feature.name is None:
-                    all_names = False
-                if feature.feature_schema_id is None:
-                    all_schemas = False
-            else:
-                if feature.value.answer.name is None:
-                    all_names = False
-                if feature.value.answer.feature_schema_id is None:
-                    all_schemas = False
-        if feature.name is None:
-            all_names = False
-        if feature.feature_schema_id is None:
-            all_schemas = False
-    return all_schemas, all_names
-
-
-def get_label_pairs(labels_a: list,
-                    labels_b: list,
-                    match_on="uid",
-                    filter_mismatch=False) -> Dict[str, Tuple[Label, Label]]:
-    """
-    This is a function to make pairing a list of prediction labels with a list of ground truth labels easier.
-    There are a few potential problems with this function.
-    We are assuming that the data row `uid` or `external id` has been provided by the user.
-    However, these particular fields are not required and can be empty.
-    If this assumption fails, then the user has to determine their own matching strategy.
-
-    Args:
-        labels_a (list): A collection of labels to match with labels_b
-        labels_b (list): A collection of labels to match with labels_a
-        match_on ('uid' or 'external_id'): The data row key to match labels by. Can either be uid or external id.
-        filter_mismatch (bool): Whether or not to ignore mismatches
-
-    Returns:
-        A dict whose keys are the union of all uids (or external ids) and
-        whose values are tuples of the matched labels
-
-    """
-
-    if match_on not in ['uid', 'external_id']:
-        raise ValueError("Can only match on `uid` or `external_id`.")
-
-    label_lookup_a = {
-        getattr(label.data, match_on, None): label for label in labels_a
-    }
-    label_lookup_b = {
-        getattr(label.data, match_on, None): label for label in labels_b
-    }
-    all_keys = set(label_lookup_a.keys()).union(label_lookup_b.keys())
-    if None in label_lookup_a or None in label_lookup_b:
-        raise ValueError(
-            f"One or more of the labels has a data row without the required key {match_on}."
-            " It cannot be determined which labels match without this information."
-            f" Either assign {match_on} to each Label or create your own pairing function."
-        )
-    pairs = defaultdict(list)
-    for key in all_keys:
-        a, b = label_lookup_a.pop(key, None), label_lookup_b.pop(key, None)
-        if a is None or b is None:
-            if not filter_mismatch:
-                raise ValueError(
-                    f"{match_on} {key} is not available in both LabelLists. "
-                    "Set `filter_mismatch = True` to filter out these examples, assign the ids manually, or create your own matching function."
-                )
-            else:
-                continue
-        pairs[key].extend([a, b])
-    return pairs
-
-
-def get_feature_pairs(
-    features_a: List[FeatureSchema], features_b: List[FeatureSchema]
-) -> Dict[str, Tuple[List[FeatureSchema], List[FeatureSchema]]]:
-    """
-    Matches features by their identifying key (name or feature schema id)
-
-    Args:
-        features_a (List[FeatureSchema]): A list of features to match with features_b
-        features_b (List[FeatureSchema]): A list of features to match with features_a
-    Returns:
-        The matched features as dict. The key will be the feature name (or
-        feature schema id) and the value will be two lists, each containing
-        the matched features from each set.
-
-    """
-    identifying_key = get_identifying_key(features_a, features_b)
-    lookup_a, lookup_b = _create_feature_lookup(
-        features_a,
-        identifying_key), _create_feature_lookup(features_b, identifying_key)
-
-    keys = set(lookup_a.keys()).union(set(lookup_b.keys()))
-    result = defaultdict(list)
-    for key in keys:
-        result[key].extend([lookup_a[key], lookup_b[key]])
-    return result
-
-
-def _create_feature_lookup(features: List[FeatureSchema],
-                           key: str) -> Dict[str, List[FeatureSchema]]:
-    """
-    Groups annotations by name (if available, otherwise feature schema id).
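-    Checklist answers are expanded into one Radio-wrapped annotation per
-    selected answer, so each selected option is grouped under its own key.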
- - Args: - annotations: List of annotations to group - Returns: - a dict where each key is the feature_schema_id (or name) - and the value is a list of annotations that have that feature_schema_id (or name) - """ - grouped_features = defaultdict(list) - for feature in features: - if isinstance(feature, ClassificationAnnotation): - #checklists - if isinstance(feature.value, Checklist): - for answer in feature.value.answer: - new_answer = Radio(answer=answer) - new_annotation = ClassificationAnnotation( - value=new_answer, - name=answer.name, - feature_schema_id=answer.feature_schema_id) - - grouped_features[getattr(answer, - key)].append(new_annotation) - elif isinstance(feature.value, Text): - grouped_features[getattr(feature, key)].append(feature) - else: - grouped_features[getattr(feature.value.answer, - key)].append(feature) - else: - grouped_features[getattr(feature, key)].append(feature) - return grouped_features - - -def has_no_matching_annotations(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation]): - if len(ground_truths) and not len(predictions): - # No existing predictions but existing ground truths means no matches. - return True - elif not len(ground_truths) and len(predictions): - # No ground truth annotations but there are predictions means no matches - return True - return False - - -def has_no_annotations(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation]): - return not len(ground_truths) and not len(predictions) - ----- -labelbox/data/metrics/confusion_matrix/calculation.py -from typing import List, Optional, Tuple, Union - -import numpy as np - -from ..iou.calculation import _get_mask_pairs, _get_vector_pairs, _get_ner_pairs, miou -from ...annotation_types import (ObjectAnnotation, ClassificationAnnotation, - Mask, Geometry, Checklist, Radio, TextEntity, - ScalarMetricValue, ConfusionMatrixMetricValue) -from ..group import (get_feature_pairs, get_identifying_key, has_no_annotations, - has_no_matching_annotations) - - -def confusion_matrix(ground_truths: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - include_subclasses: bool, - iou: float) -> ConfusionMatrixMetricValue: - """ - Computes the confusion matrix for an arbitrary set of ground truth and predicted annotations. - It first computes the confusion matrix for each metric and then sums across all classes - - Args: - ground_truth : Label containing human annotations or annotations known to be correct - prediction: Label representing model predictions - include_subclasses (bool): Whether or not to include subclasses in the calculation. - If set to True, the iou between two overlapping objects of the same type is 0 if the subclasses are not the same. 
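-            (Subclass agreement is decided by a nested miou call over the
-            paired annotations' classifications; any score below 1.0 voids
-            the match.)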
- iou: minimum overlap between objects for them to count as matching - Returns: - confusion matrix as a list: [TP,FP,TN,FN] - Returns None if there are no annotations in ground_truth or prediction annotations - """ - - annotation_pairs = get_feature_pairs(ground_truths, predictions) - conf_matrix = [ - feature_confusion_matrix(annotation_pair[0], annotation_pair[1], - include_subclasses, iou) - for annotation_pair in annotation_pairs.values() - ] - matrices = [matrix for matrix in conf_matrix if matrix is not None] - return None if not len(matrices) else np.sum(matrices, axis=0).tolist() - - -def feature_confusion_matrix( - ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]], - include_subclasses: bool, - iou: float) -> Optional[ConfusionMatrixMetricValue]: - """ - Computes confusion matrix for all features of the same class. - - Args: - ground_truths: List of ground truth annotations belonging to the same class. - predictions: List of annotations belonging to the same class. - include_subclasses (bool): Whether or not to include subclasses in the calculation. - If set to True, the iou between two overlapping objects of the same type is 0 if the subclasses are not the same. - Returns: - confusion matrix as a list: [TP,FP,TN,FN] - Returns None if there are no annotations in ground_truth or prediction annotations - """ - if has_no_matching_annotations(ground_truths, predictions): - return [0, len(predictions), 0, len(ground_truths)] - elif has_no_annotations(ground_truths, predictions): - return None - elif isinstance(predictions[0].value, Mask): - return mask_confusion_matrix(ground_truths, predictions, - include_subclasses, iou) - elif isinstance(predictions[0].value, Geometry): - return vector_confusion_matrix(ground_truths, predictions, - include_subclasses, iou) - elif isinstance(predictions[0].value, TextEntity): - return ner_confusion_matrix(ground_truths, predictions, - include_subclasses, iou) - elif isinstance(predictions[0], ClassificationAnnotation): - return classification_confusion_matrix(ground_truths, predictions) - else: - raise ValueError( - f"Unexpected annotation found. Found {type(predictions[0].value)}") - - -def classification_confusion_matrix( - ground_truths: List[ClassificationAnnotation], - predictions: List[ClassificationAnnotation] -) -> ConfusionMatrixMetricValue: - """ - Computes iou score for all features with the same feature schema id. - Because these predictions and ground truths are already sorted by schema id, - there can only be one of each (or zero if the classification was not predicted or labeled). - - Args: - ground_truths: List of ground truth classification annotations - predictions: List of prediction classification annotations - Returns: - confusion matrix as a list: [TP,FP,TN,FN] - Returns None if there are no annotations in ground_truth or prediction annotations - """ - - if has_no_matching_annotations(ground_truths, predictions): - return [0, len(predictions), 0, len(ground_truths)] - elif has_no_annotations( - ground_truths, - predictions) or len(predictions) > 1 or len(ground_truths) > 1: - # Note that we could return [0,0,0,0] but that will bloat the imports for no reason - return None - - prediction, ground_truth = predictions[0], ground_truths[0] - - if type(prediction) != type(ground_truth): - raise TypeError( - "Classification features must be the same type to compute agreement. 
" - f"Found `{type(prediction)}` and `{type(ground_truth)}`") - - if isinstance(prediction.value, Radio): - return radio_confusion_matrix(ground_truth.value, prediction.value) - elif isinstance(prediction.value, Checklist): - return checklist_confusion_matrix(ground_truth.value, prediction.value) - else: - raise ValueError( - f"Unsupported subclass. {prediction}. Only Radio and Checklist are supported" - ) - - -def vector_confusion_matrix(ground_truths: List[ObjectAnnotation], - predictions: List[ObjectAnnotation], - include_subclasses: bool, - iou: float, - buffer=70.) -> Optional[ConfusionMatrixMetricValue]: - """ - Computes confusion matrix for any vector class (point, polygon, line, rectangle). - Ground truths and predictions should all belong to the same class. - - Args: - ground_truths: List of ground truth vector annotations - predictions: List of prediction vector annotations - iou: minimum overlap between objects for them to count as matching - include_subclasses (bool): Whether or not to include subclasses in the calculation. - If set to True, the iou between two overlapping objects of the same type is 0 if the subclasses are not the same. - buffer: How much to buffer point and lines (used for determining if overlap meets iou threshold ) - Returns: - confusion matrix as a list: [TP,FP,TN,FN] - Returns None if there are no annotations in ground_truth or prediction annotations - """ - if has_no_matching_annotations(ground_truths, predictions): - return [0, len(predictions), 0, len(ground_truths)] - elif has_no_annotations(ground_truths, predictions): - return None - - pairs = _get_vector_pairs(ground_truths, predictions, buffer=buffer) - return object_pair_confusion_matrix(pairs, include_subclasses, iou) - - -def object_pair_confusion_matrix(pairs: List[Tuple[ObjectAnnotation, - ObjectAnnotation, - ScalarMetricValue]], - include_subclasses: bool, - iou: float) -> ConfusionMatrixMetricValue: - """ - Computes the confusion matrix for a list of object annotation pairs. - Performs greedy matching of pairs. - - Args: - pairs : A list of object annotation pairs with an iou score. - This is used to determine matching priority (or if objects are matching at all) since objects can only be matched once. - iou : iou threshold to deterine if objects are matching - Returns: - confusion matrix as a list: [TP,FP,TN,FN] - """ - pairs.sort(key=lambda triplet: triplet[2], reverse=True) - prediction_ids = set() - ground_truth_ids = set() - matched_predictions = set() - matched_ground_truths = set() - - for ground_truth, prediction, agreement in pairs: - prediction_id = id(prediction) - ground_truth_id = id(ground_truth) - prediction_ids.add(prediction_id) - ground_truth_ids.add(ground_truth_id) - - if agreement > iou and \ - prediction_id not in matched_predictions and \ - ground_truth_id not in matched_ground_truths: - if include_subclasses and (ground_truth.classifications or - prediction.classifications): - if miou(ground_truth.classifications, - prediction.classifications, - include_subclasses=False) < 1.: - # Incorrect if the subclasses don't 100% agree then there is no match - continue - matched_predictions.add(prediction_id) - matched_ground_truths.add(ground_truth_id) - tps = len(matched_ground_truths) - - fps = len(prediction_ids.difference(matched_predictions)) - fns = len(ground_truth_ids.difference(matched_ground_truths)) - # Not defined for object detection. 
-    tns = 0
-    return [tps, fps, tns, fns]
-
-
-def radio_confusion_matrix(ground_truth: Radio,
-                           prediction: Radio) -> ConfusionMatrixMetricValue:
-    """
-    Calculates confusion between ground truth and predicted radio values
-
-    Calculation:
-        - TNs aren't defined because we don't know how many other classes exist
-        - When P == L, then we get [1,0,0,0]
-        - when P != L, we get [0,1,0,1]
-
-    This is because we are aggregating the stats for the entire radio. Not for each class.
-    Since we are not tracking TNs (P == L) only adds to TP.
-    We are not tracking TNs because the number of TNs is equal to the number of classes which we do not know
-    from just looking at the predictions and labels. Also TNs aren't necessary for precision/recall/f1.
-    """
-    key = get_identifying_key([prediction.answer], [ground_truth.answer])
-    prediction_id = getattr(prediction.answer, key)
-    ground_truth_id = getattr(ground_truth.answer, key)
-    return [1, 0, 0, 0] if prediction_id == ground_truth_id else [0, 1, 0, 1]
-
-
-def checklist_confusion_matrix(
-        ground_truth: Checklist,
-        prediction: Checklist) -> ConfusionMatrixMetricValue:
-    """
-    Calculates agreement between ground truth and predicted checklist items:
-
-    Calculation:
-        - When a prediction matches a label, that counts as a true positive.
-        - When a prediction was made and does not have a corresponding label this is counted as a false positive
-        - When a label does not have a corresponding prediction this is counted as a false negative
-
-    We are also not tracking TNs since we don't know the number of possible classes
-    (and they aren't necessary for precision/recall/f1).
-
-    """
-    key = get_identifying_key(prediction.answer, ground_truth.answer)
-    schema_ids_pred = {getattr(answer, key) for answer in prediction.answer}
-    schema_ids_label = {getattr(answer, key) for answer in ground_truth.answer}
-    agree = schema_ids_label & schema_ids_pred
-    all_selected = schema_ids_label | schema_ids_pred
-    disagree = all_selected.difference(agree)
-    fps = len({x for x in disagree if x in schema_ids_pred})
-    fns = len({x for x in disagree if x in schema_ids_label})
-    tps = len(agree)
-    return [tps, fps, 0, fns]
-
-
-def mask_confusion_matrix(ground_truths: List[ObjectAnnotation],
-                          predictions: List[ObjectAnnotation],
-                          include_subclasses: bool,
-                          iou: float) -> Optional[ConfusionMatrixMetricValue]:
-    """
-    Computes confusion matrix metric for two masks
-
-    Important:
-        - If including subclasses in the calculation, then the metrics are computed the same as if it were object detection.
-        - Each mask is its own instance. Otherwise this metric is computed as pixel level annotations.
-
-
-def mask_confusion_matrix(ground_truths: List[ObjectAnnotation],
-                          predictions: List[ObjectAnnotation],
-                          include_subclasses: bool,
-                          iou: float) -> Optional[ConfusionMatrixMetricValue]:
-    """
-    Computes confusion matrix metric for two lists of masks.
-
-    Important:
-        - If subclasses are included in the calculation, then the metrics are computed the same as if it were object detection:
-          each mask is its own instance. Otherwise this metric is computed at the pixel level.
-
-    Args:
-        ground_truths: List of ground truth mask annotations
-        predictions: List of prediction mask annotations
-    Returns:
-        confusion matrix as a list: [TP,FP,TN,FN]
-    """
-    if has_no_matching_annotations(ground_truths, predictions):
-        return [0, len(predictions), 0, len(ground_truths)]
-    elif has_no_annotations(ground_truths, predictions):
-        return None
-
-    pairs = _get_mask_pairs(ground_truths, predictions)
-    return object_pair_confusion_matrix(pairs,
-                                        include_subclasses=include_subclasses,
-                                        iou=iou)
-
-
-def ner_confusion_matrix(ground_truths: List[ObjectAnnotation],
-                         predictions: List[ObjectAnnotation],
-                         include_subclasses: bool,
-                         iou: float) -> Optional[ConfusionMatrixMetricValue]:
-    """Computes confusion matrix metric between two lists of TextEntity objects
-
-    Args:
-        ground_truths: List of ground truth text entity annotations
-        predictions: List of prediction text entity annotations
-    Returns:
-        confusion matrix as a list: [TP,FP,TN,FN]
-    """
-    if has_no_matching_annotations(ground_truths, predictions):
-        return [0, len(predictions), 0, len(ground_truths)]
-    elif has_no_annotations(ground_truths, predictions):
-        return None
-    pairs = _get_ner_pairs(ground_truths, predictions)
-    return object_pair_confusion_matrix(pairs, include_subclasses, iou)
-
------
-labelbox/data/metrics/confusion_matrix/__init__.py
-from .calculation import *
-from .confusion_matrix import *
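-
-# Illustrative sketch (not part of the original file): the [TP, FP, TN, FN]
-# list produced by the calculations above is what the front end summarizes
-# as precision/recall/f1. Assuming `ground_truths` and `predictions` are
-# annotation lists for the same data row:
-#
-#   value = confusion_matrix(ground_truths, predictions, False, 0.5)
-#   if value is not None:
-#       tps, fps, tns, fns = value
-#       precision = tps / (tps + fps) if (tps + fps) else 0.
-#       recall = tps / (tps + fns) if (tps + fns) else 0.
-#       f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.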
-
------
-labelbox/data/metrics/confusion_matrix/confusion_matrix.py
-# type: ignore
-from collections import defaultdict
-from labelbox.data.annotation_types import feature
-from labelbox.data.annotation_types.metrics import ConfusionMatrixMetric
-from typing import List, Optional, Union
-from ...annotation_types import (Label, ObjectAnnotation,
-                                 ClassificationAnnotation)
-
-from ..group import get_feature_pairs
-from .calculation import confusion_matrix
-from .calculation import feature_confusion_matrix
-import numpy as np
-
-
-def confusion_matrix_metric(ground_truths: List[Union[
-    ObjectAnnotation, ClassificationAnnotation]],
-                            predictions: List[Union[ObjectAnnotation,
-                                                    ClassificationAnnotation]],
-                            include_subclasses=False,
-                            iou=0.5) -> List[ConfusionMatrixMetric]:
-    """
-    Computes confusion matrix metrics between two sets of annotations.
-    These annotations should relate to the same data (image/video).
-    On the front end these will be displayed as precision, recall, and f1 scores.
-
-    Args:
-        ground_truths: Annotations known to be correct (human labels)
-        predictions: Annotations representing model predictions
-        include_subclasses (bool): Whether or not to include subclasses in the calculation.
-            If set to True, the iou between two overlapping objects of the same type is 0 if the subclasses are not the same.
-    Returns:
-        Returns a list of ConfusionMatrixMetrics. Will be empty if there were no predictions and labels. Otherwise a single metric will be returned.
-    """
-    if not (0. < iou < 1.):
-        raise ValueError("iou must be between 0 and 1")
-
-    value = confusion_matrix(ground_truths, predictions, include_subclasses,
-                             iou)
-    # If both gt and preds are empty there is no metric
-    if value is None:
-        return []
-
-    metric_name = _get_metric_name(ground_truths, predictions, iou)
-    return [ConfusionMatrixMetric(metric_name=metric_name, value=value)]
-
-
-def feature_confusion_matrix_metric(
-    ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]],
-    predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]],
-    include_subclasses=False,
-    iou: float = 0.5,
-) -> List[ConfusionMatrixMetric]:
-    """
-    Computes the confusion matrix metrics for each type of class in the list of annotations.
-    These annotations should relate to the same data (image/video).
-    On the front end these will be displayed as precision, recall, and f1 scores.
-
-    Args:
-        ground_truths: Annotations known to be correct (human labels)
-        predictions: Annotations representing model predictions
-        include_subclasses (bool): Whether or not to include subclasses in the calculation.
-            If set to True, the iou between two overlapping objects of the same type is 0 if the subclasses are not the same.
-    Returns:
-        Returns a list of ConfusionMatrixMetrics.
-        There will be one metric for each class in the union of ground truth and prediction classes.
-    """
-    # Classifications are supported because we take a naive approach to them.
-    annotation_pairs = get_feature_pairs(ground_truths, predictions)
-    metrics = []
-    for key in annotation_pairs:
-        value = feature_confusion_matrix(annotation_pairs[key][0],
-                                         annotation_pairs[key][1],
-                                         include_subclasses, iou)
-        if value is None:
-            continue
-
-        metric_name = _get_metric_name(annotation_pairs[key][0],
-                                       annotation_pairs[key][1], iou)
-        metrics.append(
-            ConfusionMatrixMetric(metric_name=metric_name,
-                                  feature_name=key,
-                                  value=value))
-    return metrics
-
-
-def _get_metric_name(ground_truths: List[Union[ObjectAnnotation,
-                                               ClassificationAnnotation]],
-                     predictions: List[Union[ObjectAnnotation,
-                                             ClassificationAnnotation]],
-                     iou: float):
-    if _is_classification(ground_truths, predictions):
-        return "classification"
-
-    return f"{int(iou*100)}pct_iou"
-
-
-def _is_classification(ground_truths: List[Union[ObjectAnnotation,
-                                                 ClassificationAnnotation]],
-                       predictions: List[Union[ObjectAnnotation,
-                                               ClassificationAnnotation]]):
-    # Check if either the prediction or label contains a classification annotation
-    return (len(predictions) and
-            isinstance(predictions[0], ClassificationAnnotation) or
-            len(ground_truths) and
-            isinstance(ground_truths[0], ClassificationAnnotation))
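-
-
-# Illustrative sketch (not part of the original file): both helpers above
-# return ConfusionMatrixMetric objects ready for upload. Assuming `label`
-# and `prediction` are Label objects for the same data row:
-#
-#   metrics = feature_confusion_matrix_metric(label.annotations,
-#                                             prediction.annotations,
-#                                             iou=0.5)
-#   for metric in metrics:
-#       print(metric.metric_name, metric.feature_name, metric.value)
-#       # e.g. "50pct_iou" "dog" [1, 0, 0, 2]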
-
------
-labelbox/data/metrics/iou/calculation.py
-from typing import List, Optional, Tuple, Union
-from itertools import product
-
-import numpy as np
-from shapely.geometry import Polygon
-
-from ..group import get_feature_pairs, get_identifying_key, has_no_annotations, has_no_matching_annotations
-from ...annotation_types import (ObjectAnnotation, ClassificationAnnotation,
-                                 Mask, Geometry, Point, Line, Checklist, Text,
-                                 TextEntity, Radio, ScalarMetricValue)
-
-
-def miou(ground_truths: List[Union[ObjectAnnotation, ClassificationAnnotation]],
-         predictions: List[Union[ObjectAnnotation, ClassificationAnnotation]],
-         include_subclasses: bool) -> Optional[ScalarMetricValue]:
-    """
-    Computes miou for an arbitrary set of ground truth and predicted annotations.
-    It first computes the iou for each feature and then takes the average (weighting each feature equally).
-
-    Args:
-        ground_truths: Annotations known to be correct (human labels)
-        predictions: Annotations representing model predictions
-        include_subclasses (bool): Whether or not to include subclasses in the iou calculation.
-            If set to True, the iou between two overlapping objects of the same type is 0 if the subclasses are not the same.
-    Returns:
-        float indicating the iou score for all features represented in the annotations passed to this function.
-        Returns None if there are no annotations in ground_truth or prediction annotations
-    """
-    annotation_pairs = get_feature_pairs(predictions, ground_truths)
-    ious = [
-        feature_miou(annotation_pair[0], annotation_pair[1], include_subclasses)
-        for annotation_pair in annotation_pairs.values()
-    ]
-    ious = [iou for iou in ious if iou is not None]
-    return None if not len(ious) else np.mean(ious)
-
-
-def feature_miou(ground_truths: List[Union[ObjectAnnotation,
-                                           ClassificationAnnotation]],
-                 predictions: List[Union[ObjectAnnotation,
-                                         ClassificationAnnotation]],
-                 include_subclasses: bool) -> Optional[ScalarMetricValue]:
-    """
-    Computes iou score for all features of the same class.
-
-    Args:
-        ground_truths: List of ground truth annotations with the same feature schema.
-        predictions: List of annotations with the same feature schema.
-        include_subclasses (bool): Whether or not to include subclasses in the iou calculation.
-            If set to True, the iou between two overlapping objects of the same type is 0 if the subclasses are not the same.
-    Returns:
-        float representing the iou score for the feature type if the score can be computed, otherwise None.
-    """
-    if has_no_matching_annotations(ground_truths, predictions):
-        return 0.
-    elif has_no_annotations(ground_truths, predictions):
-        return None
-    elif isinstance(predictions[0].value, Mask):
-        return mask_miou(ground_truths, predictions, include_subclasses)
-    elif isinstance(predictions[0].value, Geometry):
-        return vector_miou(ground_truths, predictions, include_subclasses)
-    elif isinstance(predictions[0], ClassificationAnnotation):
-        return classification_miou(ground_truths, predictions)
-    elif isinstance(predictions[0].value, TextEntity):
-        return ner_miou(ground_truths, predictions, include_subclasses)
-    else:
-        raise ValueError(
-            f"Unexpected annotation found. Found {type(predictions[0].value)}")
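-
-
-# Illustrative sketch (not part of the original module): miou averages the
-# per-feature ious with equal weight per feature name. If "cat" features
-# score 0.8 and "dog" features score 0.4 on the same data row, then
-# miou(...) returns (0.8 + 0.4) / 2 = 0.6, regardless of how many objects
-# each feature has.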
-
-
-def vector_miou(ground_truths: List[ObjectAnnotation],
-                predictions: List[ObjectAnnotation],
-                include_subclasses: bool,
-                buffer=70.) -> Optional[ScalarMetricValue]:
-    """
-    Computes iou score for all features with the same feature schema id.
-    Calculation includes subclassifications.
-
-    Args:
-        ground_truths: List of ground truth vector annotations
-        predictions: List of prediction vector annotations
-    Returns:
-        float representing the iou score for the feature type.
-        Returns None if there are no annotations to match.
-    """
-    if has_no_matching_annotations(ground_truths, predictions):
-        return 0.
-    elif has_no_annotations(ground_truths, predictions):
-        return None
-    pairs = _get_vector_pairs(ground_truths, predictions, buffer=buffer)
-    return object_pair_miou(pairs, include_subclasses)
-
-
-def object_pair_miou(pairs: List[Tuple[ObjectAnnotation, ObjectAnnotation,
-                                       ScalarMetricValue]],
-                     include_subclasses) -> ScalarMetricValue:
-    pairs.sort(key=lambda triplet: triplet[2], reverse=True)
-    solution_agreements = []
-    solution_features = set()
-    all_features = set()
-    for prediction, ground_truth, agreement in pairs:
-        all_features.update({id(prediction), id(ground_truth)})
-        if id(prediction) not in solution_features and id(
-                ground_truth) not in solution_features:
-            solution_features.update({id(prediction), id(ground_truth)})
-            if include_subclasses:
-                classification_iou = miou(prediction.classifications,
-                                          ground_truth.classifications,
-                                          include_subclasses=False)
-                classification_iou = classification_iou if classification_iou is not None else agreement
-                solution_agreements.append(
-                    (agreement + classification_iou) / 2.)
-            else:
-                solution_agreements.append(agreement)
-
-    # Add zeros for unmatched features
-    solution_agreements.extend([0.0] *
-                               (len(all_features) - len(solution_features)))
-    return np.mean(solution_agreements)
-
-
-def mask_miou(ground_truths: List[ObjectAnnotation],
-              predictions: List[ObjectAnnotation],
-              include_subclasses: bool) -> Optional[ScalarMetricValue]:
-    """
-    Computes iou score for all features with the same feature schema id.
-    Calculation includes subclassifications.
-
-    Args:
-        ground_truths: List of ground truth mask annotations
-        predictions: List of prediction mask annotations
-    Returns:
-        float representing the iou score for the masks
-    """
-    if has_no_matching_annotations(ground_truths, predictions):
-        return 0.
-    elif has_no_annotations(ground_truths, predictions):
-        return None
-
-    if include_subclasses:
-        pairs = _get_mask_pairs(ground_truths, predictions)
-        return object_pair_miou(pairs, include_subclasses=include_subclasses)
-
-    prediction_np = np.max([pred.value.draw(color=1) for pred in predictions],
-                           axis=0)
-    ground_truth_np = np.max(
-        [ground_truth.value.draw(color=1) for ground_truth in ground_truths],
-        axis=0)
-    if prediction_np.shape != ground_truth_np.shape:
-        raise ValueError(
-            "Prediction and ground truth masks must have the same shape."
-            f" Found {prediction_np.shape}/{ground_truth_np.shape}.")
-
-    return _mask_iou(ground_truth_np, prediction_np)
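-
-
-# Illustrative sketch (not part of the original module): in the
-# include_subclasses=False branch above, all masks on each side are
-# flattened onto one binary canvas before a single pixel-level iou is taken:
-#
-#   a = np.array([[1, 1, 0, 0]])
-#   b = np.array([[0, 1, 1, 0]])
-#   np.sum(a & b) / np.sum(a | b)  # -> 1/3, exactly what _mask_iou computes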
" - f"Found `{type(prediction)}` and `{type(ground_truth)}`") - - if isinstance(prediction.value, Text): - return text_iou(ground_truth.value, prediction.value) - elif isinstance(prediction.value, Radio): - return radio_iou(ground_truth.value, prediction.value) - elif isinstance(prediction.value, Checklist): - return checklist_iou(ground_truth.value, prediction.value) - else: - raise ValueError(f"Unsupported subclass. {prediction}.") - - -def radio_iou(ground_truth: Radio, prediction: Radio) -> ScalarMetricValue: - """ - Calculates agreement between ground truth and predicted radio values - """ - key = get_identifying_key([prediction.answer], [ground_truth.answer]) - return float( - getattr(prediction.answer, key) == getattr(ground_truth.answer, key)) - - -def text_iou(ground_truth: Text, prediction: Text) -> ScalarMetricValue: - """ - Calculates agreement between ground truth and predicted text - """ - return float(prediction.answer == ground_truth.answer) - - -def checklist_iou(ground_truth: Checklist, - prediction: Checklist) -> ScalarMetricValue: - """ - Calculates agreement between ground truth and predicted checklist items - """ - key = get_identifying_key(prediction.answer, ground_truth.answer) - schema_ids_pred = {getattr(answer, key) for answer in prediction.answer} - schema_ids_label = {getattr(answer, key) for answer in ground_truth.answer} - return float( - len(schema_ids_label & schema_ids_pred) / - len(schema_ids_label | schema_ids_pred)) - - -def _get_vector_pairs( - ground_truths: List[ObjectAnnotation], predictions: List[ObjectAnnotation], - buffer: float -) -> List[Tuple[ObjectAnnotation, ObjectAnnotation, ScalarMetricValue]]: - """ - # Get iou score for all pairs of ground truths and predictions - """ - pairs = [] - for ground_truth, prediction in product(ground_truths, predictions): - if isinstance(prediction.value, Geometry) and isinstance( - ground_truth.value, Geometry): - if isinstance(prediction.value, (Line, Point)): - - score = _polygon_iou(prediction.value.shapely.buffer(buffer), - ground_truth.value.shapely.buffer(buffer)) - else: - score = _polygon_iou(prediction.value.shapely, - ground_truth.value.shapely) - pairs.append((ground_truth, prediction, score)) - return pairs - - -def _get_mask_pairs( - ground_truths: List[ObjectAnnotation], predictions: List[ObjectAnnotation] -) -> List[Tuple[ObjectAnnotation, ObjectAnnotation, ScalarMetricValue]]: - """ - # Get iou score for all pairs of ground truths and predictions - """ - pairs = [] - for ground_truth, prediction in product(ground_truths, predictions): - if isinstance(prediction.value, Mask) and isinstance( - ground_truth.value, Mask): - score = _mask_iou(prediction.value.draw(color=1), - ground_truth.value.draw(color=1)) - pairs.append((ground_truth, prediction, score)) - return pairs - - -def _polygon_iou(poly1: Polygon, poly2: Polygon) -> ScalarMetricValue: - """Computes iou between two shapely polygons.""" - poly1, poly2 = _ensure_valid_poly(poly1), _ensure_valid_poly(poly2) - if poly1.intersects(poly2): - return poly1.intersection(poly2).area / poly1.union(poly2).area - return 0. 
-
-
-def _ensure_valid_poly(poly):
-    if not poly.is_valid:
-        return poly.buffer(0)
-    return poly
-
-
-def _mask_iou(mask1: np.ndarray, mask2: np.ndarray) -> ScalarMetricValue:
-    """Computes iou between two binary segmentation masks."""
-    return np.sum(mask1 & mask2) / np.sum(mask1 | mask2)
-
-
-def _get_ner_pairs(
-    ground_truths: List[ObjectAnnotation], predictions: List[ObjectAnnotation]
-) -> List[Tuple[ObjectAnnotation, ObjectAnnotation, ScalarMetricValue]]:
-    """Get the iou score for all possible pairs of ground truths and predictions"""
-    pairs = []
-    for ground_truth, prediction in product(ground_truths, predictions):
-        score = _ner_iou(ground_truth.value, prediction.value)
-        pairs.append((ground_truth, prediction, score))
-    return pairs
-
-
-def _ner_iou(ner1: TextEntity, ner2: TextEntity):
-    """Computes iou between two text entity annotations"""
-    intersection_start, intersection_end = max(ner1.start, ner2.start), min(
-        ner1.end, ner2.end)
-    union_start, union_end = min(ner1.start,
-                                 ner2.start), max(ner1.end, ner2.end)
-    # Edge case: only one character in the text
-    if union_start == union_end:
-        return 1
-    # If there is no intersection
-    if intersection_start > intersection_end:
-        return 0
-    return (intersection_end - intersection_start) / (union_end - union_start)
-
-
-def ner_miou(ground_truths: List[ObjectAnnotation],
-             predictions: List[ObjectAnnotation],
-             include_subclasses: bool) -> Optional[ScalarMetricValue]:
-    """
-    Computes iou score for all features with the same feature schema id.
-    Calculation includes subclassifications.
-
-    Args:
-        ground_truths: List of ground truth ner annotations
-        predictions: List of prediction ner annotations
-    Returns:
-        float representing the iou score for the feature type.
-        Returns None if there are no annotations to match.
-    """
-    if has_no_matching_annotations(ground_truths, predictions):
-        return 0.
-    elif has_no_annotations(ground_truths, predictions):
-        return None
-    pairs = _get_ner_pairs(ground_truths, predictions)
-    return object_pair_miou(pairs, include_subclasses)
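-
-
-# Illustrative sketch (not part of the original module): _ner_iou above is a
-# 1-D interval iou over character offsets. For hypothetical spans covering
-# characters [0, 10) and [5, 15), the intersection is 5 characters and the
-# union is 15, so:
-#
-#   _ner_iou(TextEntity(start=0, end=10), TextEntity(start=5, end=15))  # -> 1/3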
- """ - iou = miou(ground_truths, predictions, include_subclasses) - # If both gt and preds are empty there is no metric - if iou is None: - return [] - return [ScalarMetric(metric_name="custom_iou", value=iou)] - - -def feature_miou_metric(ground_truths: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - predictions: List[Union[ObjectAnnotation, - ClassificationAnnotation]], - include_subclasses=True) -> List[ScalarMetric]: - """ - Computes the miou for each type of class in the list of annotations. - These annotations should relate to the same data (image/video). - - Args: - ground_truth : Label containing human annotations or annotations known to be correct - prediction: Label representing model predictions - include_subclasses (bool): Whether or not to include subclasses in the iou calculation. - If set to True, the iou between two overlapping objects of the same type is 0 if the subclasses are not the same. - Returns: - Returns a list of ScalarMetrics. - There will be one metric for each class in the union of ground truth and prediction classes. - """ - # Classifications are supported because we just take a naive approach to them.. - annotation_pairs = get_feature_pairs(predictions, ground_truths) - metrics = [] - for key in annotation_pairs: - - value = feature_miou(annotation_pairs[key][0], annotation_pairs[key][1], - include_subclasses) - if value is None: - continue - metrics.append( - ScalarMetric(metric_name="custom_iou", - feature_name=key, - value=value)) - return metrics - - -def data_row_miou(ground_truth: Label, - prediction: Label, - include_subclasses=False) -> Optional[float]: - """ - - This function is no longer supported. Use miou() for raw values or miou_metric() for the metric - - Calculates iou for two labels corresponding to the same data row. - - Args: - ground_truth : Label containing human annotations or annotations known to be correct - prediction: Label representing model predictions - Returns: - float indicating the iou score for this data row. 
- Returns None if there are no annotations in ground_truth or prediction Labels - """ - return miou(ground_truth.annotations, prediction.annotations, - include_subclasses) - ----- -labelbox/data/serialization/__init__.py -from .labelbox_v1 import LBV1Converter -from .ndjson import NDJsonConverter -from .coco import COCOConverter - ----- -labelbox/data/serialization/ndjson/relationship.py -from typing import Union -from labelbox import pydantic_compat -from .base import NDAnnotation, DataRow -from ...annotation_types.data import ImageData, TextData -from ...annotation_types.relationship import RelationshipAnnotation -from ...annotation_types.relationship import Relationship -from .objects import NDObjectType -from .base import DataRow - -SUPPORTED_ANNOTATIONS = NDObjectType - - -class _Relationship(pydantic_compat.BaseModel): - source: str - target: str - type: str - - -class NDRelationship(NDAnnotation): - relationship: _Relationship - - @staticmethod - def to_common(annotation: "NDRelationship", source: SUPPORTED_ANNOTATIONS, - target: SUPPORTED_ANNOTATIONS) -> RelationshipAnnotation: - return RelationshipAnnotation(name=annotation.name, - value=Relationship( - source=source, - target=target, - type=Relationship.Type( - annotation.relationship.type)), - extra={'uuid': annotation.uuid}, - feature_schema_id=annotation.schema_id) - - @classmethod - def from_common(cls, annotation: RelationshipAnnotation, - data: Union[ImageData, TextData]) -> "NDRelationship": - relationship = annotation.value - return cls(uuid=str(annotation._uuid), - name=annotation.name, - dataRow=DataRow(id=data.uid, global_key=data.global_key), - relationship=_Relationship( - source=str(relationship.source._uuid), - target=str(relationship.target._uuid), - type=relationship.type.value)) - ----- -labelbox/data/serialization/ndjson/classification.py -from typing import Any, Dict, List, Union, Optional - -from labelbox import pydantic_compat -from labelbox.data.mixins import ConfidenceMixin, CustomMetric, CustomMetricsMixin -from labelbox.data.serialization.ndjson.base import DataRow, NDAnnotation - -from labelbox.utils import camel_case -from ...annotation_types.annotation import ClassificationAnnotation -from ...annotation_types.video import VideoClassificationAnnotation -from ...annotation_types.classification.classification import ClassificationAnswer, Dropdown, Text, Checklist, Radio -from ...annotation_types.types import Cuid -from ...annotation_types.data import TextData, VideoData, ImageData - - -class NDAnswer(ConfidenceMixin, CustomMetricsMixin): - name: Optional[str] = None - schema_id: Optional[Cuid] = None - classifications: Optional[List['NDSubclassificationType']] = [] - - @pydantic_compat.root_validator() - def must_set_one(cls, values): - if ('schema_id' not in values or values['schema_id'] - is None) and ('name' not in values or values['name'] is None): - raise ValueError("Schema id or name are not set. 
Set either one.") - return values - - def dict(self, *args, **kwargs): - res = super().dict(*args, **kwargs) - if 'name' in res and res['name'] is None: - res.pop('name') - if 'schemaId' in res and res['schemaId'] is None: - res.pop('schemaId') - if self.classifications is None or len(self.classifications) == 0: - res.pop('classifications') - else: - res['classifications'] = [ - c.dict(*args, **kwargs) for c in self.classifications - ] - return res - - class Config: - allow_population_by_field_name = True - alias_generator = camel_case - - -class FrameLocation(pydantic_compat.BaseModel): - end: int - start: int - - -class VideoSupported(pydantic_compat.BaseModel): - # Note that frames are only allowed as top level inferences for video - frames: Optional[List[FrameLocation]] = None - - def dict(self, *args, **kwargs): - res = super().dict(*args, **kwargs) - # This means these are no video frames .. - if self.frames is None: - res.pop('frames') - return res - - -class NDTextSubclass(NDAnswer): - answer: str - - def to_common(self) -> Text: - return Text(answer=self.answer, - confidence=self.confidence, - custom_metrics=self.custom_metrics) - - @classmethod - def from_common(cls, text: Text, name: str, - feature_schema_id: Cuid) -> "NDTextSubclass": - return cls( - answer=text.answer, - name=name, - schema_id=feature_schema_id, - confidence=text.confidence, - custom_metrics=text.custom_metrics, - ) - - -class NDChecklistSubclass(NDAnswer): - answer: List[NDAnswer] = pydantic_compat.Field(..., alias='answers') - - def to_common(self) -> Checklist: - - return Checklist(answer=[ - ClassificationAnswer(name=answer.name, - feature_schema_id=answer.schema_id, - confidence=answer.confidence, - classifications=[ - NDSubclassification.to_common(annot) - for annot in answer.classifications - ], - custom_metrics=answer.custom_metrics) - for answer in self.answer - ]) - - @classmethod - def from_common(cls, checklist: Checklist, name: str, - feature_schema_id: Cuid) -> "NDChecklistSubclass": - return cls(answer=[ - NDAnswer(name=answer.name, - schema_id=answer.feature_schema_id, - confidence=answer.confidence, - classifications=[ - NDSubclassification.from_common(annot) - for annot in answer.classifications - ], - custom_metrics=answer.custom_metrics) - for answer in checklist.answer - ], - name=name, - schema_id=feature_schema_id) - - def dict(self, *args, **kwargs): - res = super().dict(*args, **kwargs) - if 'answers' in res: - res['answer'] = res.pop('answers') - return res - - -class NDRadioSubclass(NDAnswer): - answer: NDAnswer - - def to_common(self) -> Radio: - return Radio(answer=ClassificationAnswer( - name=self.answer.name, - feature_schema_id=self.answer.schema_id, - confidence=self.answer.confidence, - classifications=[ - NDSubclassification.to_common(annot) - for annot in self.answer.classifications - ], - custom_metrics=self.answer.custom_metrics)) - - @classmethod - def from_common(cls, radio: Radio, name: str, - feature_schema_id: Cuid) -> "NDRadioSubclass": - return cls(answer=NDAnswer(name=radio.answer.name, - schema_id=radio.answer.feature_schema_id, - confidence=radio.answer.confidence, - classifications=[ - NDSubclassification.from_common(annot) - for annot in radio.answer.classifications - ], - custom_metrics=radio.answer.custom_metrics), - name=name, - schema_id=feature_schema_id) - - -# ====== End of subclasses - - -class NDText(NDAnnotation, NDTextSubclass): - - @classmethod - def from_common(cls, - uuid: str, - text: Text, - name: str, - feature_schema_id: Cuid, - extra: 
Dict[str, Any], - data: Union[TextData, ImageData], - message_id: str, - confidence: Optional[float] = None) -> "NDText": - return cls( - answer=text.answer, - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - message_id=message_id, - confidence=text.confidence, - custom_metrics=text.custom_metrics, - ) - - -class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported): - - @classmethod - def from_common( - cls, - uuid: str, - checklist: Checklist, - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[VideoData, TextData, ImageData], - message_id: str, - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None - ) -> "NDChecklist": - - return cls(answer=[ - NDAnswer(name=answer.name, - schema_id=answer.feature_schema_id, - confidence=answer.confidence, - classifications=[ - NDSubclassification.from_common(annot) - for annot in answer.classifications - ], - custom_metrics=answer.custom_metrics) - for answer in checklist.answer - ], - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - frames=extra.get('frames'), - message_id=message_id, - confidence=confidence) - - -class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported): - - @classmethod - def from_common( - cls, - uuid: str, - radio: Radio, - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[VideoData, TextData, ImageData], - message_id: str, - confidence: Optional[float] = None, - ) -> "NDRadio": - return cls(answer=NDAnswer(name=radio.answer.name, - schema_id=radio.answer.feature_schema_id, - confidence=radio.answer.confidence, - classifications=[ - NDSubclassification.from_common(annot) - for annot in radio.answer.classifications - ], - custom_metrics=radio.answer.custom_metrics), - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - frames=extra.get('frames'), - message_id=message_id, - confidence=confidence) - - -class NDSubclassification: - - @classmethod - def from_common( - cls, annotation: ClassificationAnnotation - ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: - classify_obj = cls.lookup_subclassification(annotation) - if classify_obj is None: - raise TypeError( - f"Unable to convert object to MAL format. 
`{type(annotation.value)}`" - ) - return classify_obj.from_common(annotation.value, annotation.name, - annotation.feature_schema_id) - - @staticmethod - def to_common( - annotation: "NDClassificationType") -> ClassificationAnnotation: - return ClassificationAnnotation(value=annotation.to_common(), - name=annotation.name, - feature_schema_id=annotation.schema_id) - - @staticmethod - def lookup_subclassification( - annotation: ClassificationAnnotation - ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: - if isinstance(annotation.value, Dropdown): - raise TypeError("Dropdowns are not supported for MAL.") - return { - Text: NDTextSubclass, - Checklist: NDChecklistSubclass, - Radio: NDRadioSubclass, - }.get(type(annotation.value)) - - -class NDClassification: - - @staticmethod - def to_common( - annotation: "NDClassificationType" - ) -> Union[ClassificationAnnotation, VideoClassificationAnnotation]: - common = ClassificationAnnotation( - value=annotation.to_common(), - name=annotation.name, - feature_schema_id=annotation.schema_id, - extra={'uuid': annotation.uuid}, - message_id=annotation.message_id, - confidence=annotation.confidence, - ) - - if getattr(annotation, 'frames', None) is None: - return [common] - results = [] - for frame in annotation.frames: - for idx in range(frame.start, frame.end + 1, 1): - results.append( - VideoClassificationAnnotation(frame=idx, **common.dict())) - return results - - @classmethod - def from_common( - cls, annotation: Union[ClassificationAnnotation, - VideoClassificationAnnotation], - data: Union[VideoData, TextData, ImageData] - ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: - classify_obj = cls.lookup_classification(annotation) - if classify_obj is None: - raise TypeError( - f"Unable to convert object to MAL format. 
`{type(annotation.value)}`"
-            )
-        return classify_obj.from_common(str(annotation._uuid), annotation.value,
-                                        annotation.name,
-                                        annotation.feature_schema_id,
-                                        annotation.extra, data,
-                                        annotation.message_id,
-                                        annotation.confidence)
-
-    @staticmethod
-    def lookup_classification(
-        annotation: Union[ClassificationAnnotation,
-                          VideoClassificationAnnotation]
-    ) -> Union[NDText, NDChecklist, NDRadio]:
-        if isinstance(annotation.value, Dropdown):
-            raise TypeError("Dropdowns are not supported for MAL.")
-        return {
-            Text: NDText,
-            Checklist: NDChecklist,
-            Radio: NDRadio
-        }.get(type(annotation.value))
-
-
-# Make sure to keep NDChecklistSubclass prior to NDRadioSubclass in the list,
-# otherwise a list of answers gets parsed by NDRadio whereas NDChecklist must be used
-NDSubclassificationType = Union[NDChecklistSubclass, NDRadioSubclass,
-                                NDTextSubclass]
-
-NDAnswer.update_forward_refs()
-NDChecklistSubclass.update_forward_refs()
-NDChecklist.update_forward_refs()
-NDRadioSubclass.update_forward_refs()
-NDRadio.update_forward_refs()
-NDText.update_forward_refs()
-NDTextSubclass.update_forward_refs()
-
-# Make sure to keep NDChecklist prior to NDRadio in the list,
-# otherwise a list of answers gets parsed by NDRadio whereas NDChecklist must be used
-NDClassificationType = Union[NDChecklist, NDRadio, NDText]
-
------
-labelbox/data/serialization/ndjson/converter.py
-import copy
-import logging
-import uuid
-from collections import defaultdict, deque
-from typing import Any, Deque, Dict, Generator, Iterable, List, Set, Union
-
-from labelbox.data.annotation_types.annotation import ObjectAnnotation
-from labelbox.data.annotation_types.classification.classification import (
-    ClassificationAnnotation,)
-from labelbox.data.annotation_types.metrics.confusion_matrix import (
-    ConfusionMatrixMetric,)
-from labelbox.data.annotation_types.metrics.scalar import ScalarMetric
-from labelbox.data.annotation_types.video import VideoMaskAnnotation
-
-from ...annotation_types.collection import LabelCollection, LabelGenerator
-from ...annotation_types.relationship import RelationshipAnnotation
-from .label import NDLabel
-
-logger = logging.getLogger(__name__)
-
-IGNORE_IF_NONE = ["page", "unit", "messageId"]
-
-
-class NDJsonConverter:
-
-    @staticmethod
-    def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator:
-        """
-        Converts ndjson data (prediction import format) into the common labelbox format.
-
-        Args:
-            json_data: An iterable representing the ndjson data
-        Returns:
-            LabelGenerator containing the ndjson data.
-        """
-        data = NDLabel(**{"annotations": json_data})
-        res = data.to_common()
-        return res
-
-    @staticmethod
-    def serialize(
-            labels: LabelCollection) -> Generator[Dict[str, Any], None, None]:
-        """
-        Converts a labelbox common object to the labelbox ndjson format (prediction import format)
-
-        Note that this function might fail for objects that are not supported by MAL.
-        Not all edge cases are handled by custom exceptions; if you get a cryptic pydantic error message, it is probably due to this.
-        We will continue to improve the error messages and add helper functions to deal with this.
-
-        Args:
-            labels: Either a list of Label objects or a LabelGenerator
-        Returns:
-            A generator for accessing the ndjson representation of the data
-        """
-        used_uuids: Set[uuid.UUID] = set()
-
-        relationship_uuids: Dict[uuid.UUID,
-                                 Deque[uuid.UUID]] = defaultdict(deque)
-
-        # UUIDs are private properties used to enhance UX when defining relationships.
-        # They are created for all annotations, but only utilized for relationships.
-        # To avoid overwriting, UUIDs must be unique across labels.
-        # Non-relationship annotation UUIDs are regenerated when they are reused.
-        # For relationship annotations, during the first pass, we update the UUIDs of the source and target annotations.
-        # During the second pass, we update the UUIDs of the annotations referenced by the relationship annotations.
-        for label in labels:
-            uuid_safe_annotations: List[Union[
-                ClassificationAnnotation,
-                ObjectAnnotation,
-                VideoMaskAnnotation,
-                ScalarMetric,
-                ConfusionMatrixMetric,
-                RelationshipAnnotation,
-            ]] = []
-            # First pass to get all RelationshipAnnotations
-            # and update the UUIDs of the source and target annotations
-            for annotation in label.annotations:
-                if isinstance(annotation, RelationshipAnnotation):
-                    annotation = copy.deepcopy(annotation)
-                    new_source_uuid = uuid.uuid4()
-                    new_target_uuid = uuid.uuid4()
-                    relationship_uuids[annotation.value.source._uuid].append(
-                        new_source_uuid)
-                    relationship_uuids[annotation.value.target._uuid].append(
-                        new_target_uuid)
-                    annotation.value.source._uuid = new_source_uuid
-                    annotation.value.target._uuid = new_target_uuid
-                    if annotation._uuid in used_uuids:
-                        annotation._uuid = uuid.uuid4()
-                    used_uuids.add(annotation._uuid)
-                    uuid_safe_annotations.append(annotation)
-            # Second pass to update UUIDs for annotations referenced by RelationshipAnnotations
-            for annotation in label.annotations:
-                if (not isinstance(annotation, RelationshipAnnotation) and
-                        hasattr(annotation, "_uuid")):
-                    annotation = copy.deepcopy(annotation)
-                    next_uuids = relationship_uuids[annotation._uuid]
-                    if len(next_uuids) > 0:
-                        annotation._uuid = next_uuids.popleft()
-
-                    if annotation._uuid in used_uuids:
-                        annotation._uuid = uuid.uuid4()
-                    used_uuids.add(annotation._uuid)
-                    uuid_safe_annotations.append(annotation)
-                else:
-                    if not isinstance(annotation, RelationshipAnnotation):
-                        uuid_safe_annotations.append(annotation)
-            label.annotations = uuid_safe_annotations
-            for example in NDLabel.from_common([label]):
-                annotation_uuid = getattr(example, "uuid", None)
-
-                res = example.dict(
-                    by_alias=True,
-                    exclude={"uuid"} if annotation_uuid == "None" else None,
-                )
-                for k, v in list(res.items()):
-                    if k in IGNORE_IF_NONE and v is None:
-                        del res[k]
-                yield res
-
------
-labelbox/data/serialization/ndjson/__init__.py
-from .converter import NDJsonConverter
-
------
-labelbox/data/serialization/ndjson/metric.py
-from typing import Optional, Union, Type
-
-from labelbox.data.annotation_types.data import ImageData, TextData
-from labelbox.data.serialization.ndjson.base import DataRow, NDJsonBase
-from labelbox.data.annotation_types.metrics.scalar import (
-    ScalarMetric, ScalarMetricAggregation, ScalarMetricValue,
-    ScalarMetricConfidenceValue)
-from labelbox.data.annotation_types.metrics.confusion_matrix import (
-    ConfusionMatrixAggregation, ConfusionMatrixMetric,
-    ConfusionMatrixMetricValue, ConfusionMatrixMetricConfidenceValue)
-
-
-class BaseNDMetric(NDJsonBase):
-    metric_value: float
-    feature_name: Optional[str] = None
-    subclass_name: Optional[str] = None
-
-    class Config:
-        use_enum_values = True
-
-    def dict(self, *args, **kwargs):
-        res = super().dict(*args, **kwargs)
-        for field in ['featureName', 'subclassName']:
-            if res[field] is None:
-                res.pop(field)
-        return res
-
-
-class NDConfusionMatrixMetric(BaseNDMetric):
-    metric_value: Union[ConfusionMatrixMetricValue,
-                        ConfusionMatrixMetricConfidenceValue]
-    metric_name: str
aggregation: ConfusionMatrixAggregation - - def to_common(self) -> ConfusionMatrixMetric: - return ConfusionMatrixMetric(value=self.metric_value, - metric_name=self.metric_name, - feature_name=self.feature_name, - subclass_name=self.subclass_name, - aggregation=self.aggregation, - extra={'uuid': self.uuid}) - - @classmethod - def from_common( - cls, metric: ConfusionMatrixMetric, - data: Union[TextData, ImageData]) -> "NDConfusionMatrixMetric": - return cls(uuid=metric.extra.get('uuid'), - metric_value=metric.value, - metric_name=metric.metric_name, - feature_name=metric.feature_name, - subclass_name=metric.subclass_name, - aggregation=metric.aggregation, - data_row=DataRow(id=data.uid, global_key=data.global_key)) - - -class NDScalarMetric(BaseNDMetric): - metric_value: Union[ScalarMetricValue, ScalarMetricConfidenceValue] - metric_name: Optional[str] - aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN - - def to_common(self) -> ScalarMetric: - return ScalarMetric(value=self.metric_value, - metric_name=self.metric_name, - feature_name=self.feature_name, - subclass_name=self.subclass_name, - aggregation=self.aggregation, - extra={'uuid': self.uuid}) - - @classmethod - def from_common(cls, metric: ScalarMetric, - data: Union[TextData, ImageData]) -> "NDScalarMetric": - return cls(uuid=metric.extra.get('uuid'), - metric_value=metric.value, - metric_name=metric.metric_name, - feature_name=metric.feature_name, - subclass_name=metric.subclass_name, - aggregation=metric.aggregation.value, - data_row=DataRow(id=data.uid, global_key=data.global_key)) - - def dict(self, *args, **kwargs): - res = super().dict(*args, **kwargs) - # For backwards compatibility. - if res['metricName'] is None: - res.pop('metricName') - res.pop('aggregation') - return res - - -class NDMetricAnnotation: - - @classmethod - def to_common( - cls, annotation: Union[NDScalarMetric, NDConfusionMatrixMetric] - ) -> Union[ScalarMetric, ConfusionMatrixMetric]: - return annotation.to_common() - - @classmethod - def from_common( - cls, annotation: Union[ScalarMetric, - ConfusionMatrixMetric], data: Union[TextData, - ImageData] - ) -> Union[NDScalarMetric, NDConfusionMatrixMetric]: - obj = cls.lookup_object(annotation) - return obj.from_common(annotation, data) - - @staticmethod - def lookup_object( - annotation: Union[ScalarMetric, ConfusionMatrixMetric] - ) -> Union[Type[NDScalarMetric], Type[NDConfusionMatrixMetric]]: - result = { - ScalarMetric: NDScalarMetric, - ConfusionMatrixMetric: NDConfusionMatrixMetric, - }.get(type(annotation)) - if result is None: - raise TypeError( - f"Unable to convert object to MAL format. 
`{type(annotation)}`") - return result - ----- -labelbox/data/serialization/ndjson/label.py -from itertools import groupby -from operator import itemgetter -from typing import Dict, Generator, List, Tuple, Union -from collections import defaultdict -import warnings - -from labelbox import pydantic_compat - -from ...annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation -from ...annotation_types.relationship import RelationshipAnnotation -from ...annotation_types.video import DICOMObjectAnnotation, VideoClassificationAnnotation -from ...annotation_types.video import VideoObjectAnnotation, VideoMaskAnnotation -from ...annotation_types.collection import LabelCollection, LabelGenerator -from ...annotation_types.data import DicomData, ImageData, TextData, VideoData -from ...annotation_types.label import Label -from ...annotation_types.ner import TextEntity, ConversationEntity -from ...annotation_types.classification import Dropdown -from ...annotation_types.metrics import ScalarMetric, ConfusionMatrixMetric - -from .metric import NDScalarMetric, NDMetricAnnotation, NDConfusionMatrixMetric -from .classification import NDChecklistSubclass, NDClassification, NDClassificationType, NDRadioSubclass -from .objects import NDObject, NDObjectType, NDSegments, NDDicomSegments, NDVideoMasks, NDDicomMasks -from .relationship import NDRelationship -from .base import DataRow - -AnnotationType = Union[NDObjectType, NDClassificationType, - NDConfusionMatrixMetric, NDScalarMetric, NDDicomSegments, - NDSegments, NDDicomMasks, NDVideoMasks, NDRelationship] - - -class NDLabel(pydantic_compat.BaseModel): - annotations: List[AnnotationType] - - class _Relationship(pydantic_compat.BaseModel): - """This object holds information about the relationship""" - ndjson: NDRelationship - source: str - target: str - - class _AnnotationGroup(pydantic_compat.BaseModel): - """Stores all the annotations and relationships per datarow""" - data_row: DataRow = None - ndjson_annotations: Dict[str, AnnotationType] = {} - relationships: List["NDLabel._Relationship"] = [] - - def to_common(self) -> LabelGenerator: - annotation_groups = defaultdict(NDLabel._AnnotationGroup) - - for ndjson_annotation in self.annotations: - key = ndjson_annotation.data_row.id or ndjson_annotation.data_row.global_key - group = annotation_groups[key] - - if isinstance(ndjson_annotation, NDRelationship): - group.relationships.append( - NDLabel._Relationship( - ndjson=ndjson_annotation, - source=ndjson_annotation.relationship.source, - target=ndjson_annotation.relationship.target)) - else: - # if this is the first object in this group, we - # take note of the DataRow this group belongs to - # and store it in the _AnnotationGroupTuple - if not group.ndjson_annotations: - group.data_row = ndjson_annotation.data_row - - # if this assertion fails and it's a valid case, - # we need to change the value type of - # `_AnnotationGroupTuple.ndjson_objects` to accept a list of objects - # and adapt the code to support duplicate UUIDs - assert ndjson_annotation.uuid not in group.ndjson_annotations, f"UUID '{ndjson_annotation.uuid}' is not unique" - - group.ndjson_annotations[ - ndjson_annotation.uuid] = ndjson_annotation - - return LabelGenerator( - data=self._generate_annotations(annotation_groups)) - - @classmethod - def from_common(cls, - data: LabelCollection) -> Generator["NDLabel", None, None]: - for label in data: - yield from cls._create_non_video_annotations(label) - yield from cls._create_video_annotations(label) - - def 
_generate_annotations( - self, annotation_groups: Dict[str, _AnnotationGroup] - ) -> Generator[Label, None, None]: - for _, group in annotation_groups.items(): - relationship_annotations: Dict[str, ObjectAnnotation] = {} - annotations = [] - # first, we iterate through all the NDJSON objects and store the - # deserialized objects in the _AnnotationGroupTuple - # object *if* the object can be used in a relationship - for uuid, ndjson_annotation in group.ndjson_annotations.items(): - if isinstance(ndjson_annotation, NDDicomSegments): - annotations.extend( - NDDicomSegments.to_common(ndjson_annotation, - ndjson_annotation.name, - ndjson_annotation.schema_id)) - elif isinstance(ndjson_annotation, NDSegments): - annotations.extend( - NDSegments.to_common(ndjson_annotation, - ndjson_annotation.name, - ndjson_annotation.schema_id)) - elif isinstance(ndjson_annotation, NDDicomMasks): - annotations.append( - NDDicomMasks.to_common(ndjson_annotation)) - elif isinstance(ndjson_annotation, NDVideoMasks): - annotations.append( - NDVideoMasks.to_common(ndjson_annotation)) - elif isinstance(ndjson_annotation, NDObjectType.__args__): - annotation = NDObject.to_common(ndjson_annotation) - annotations.append(annotation) - relationship_annotations[uuid] = annotation - elif isinstance(ndjson_annotation, - NDClassificationType.__args__): - annotations.extend( - NDClassification.to_common(ndjson_annotation)) - elif isinstance(ndjson_annotation, - (NDScalarMetric, NDConfusionMatrixMetric)): - annotations.append( - NDMetricAnnotation.to_common(ndjson_annotation)) - else: - raise TypeError( - f"Unsupported annotation. {type(ndjson_annotation)}") - - # after all the annotations have been discovered, we can now create - # the relationship objects and use references to the objects - # involved - for relationship in group.relationships: - try: - source, target = relationship_annotations[ - relationship.source], relationship_annotations[ - relationship.target] - except KeyError: - raise ValueError( - f"Relationship object refers to nonexistent object with UUID '{relationship.source}' and/or '{relationship.target}'" - ) - annotations.append( - NDRelationship.to_common(relationship.ndjson, source, - target)) - - yield Label(annotations=annotations, - data=self._infer_media_type(group.data_row, - annotations)) - - def _infer_media_type( - self, data_row: DataRow, - annotations: List[Union[TextEntity, ConversationEntity, - VideoClassificationAnnotation, - DICOMObjectAnnotation, VideoObjectAnnotation, - ObjectAnnotation, ClassificationAnnotation, - ScalarMetric, ConfusionMatrixMetric]] - ) -> Union[TextData, VideoData, ImageData]: - if len(annotations) == 0: - raise ValueError("Missing annotations while inferring media type") - - types = {type(annotation) for annotation in annotations} - data = ImageData - if (TextEntity in types) or (ConversationEntity in types): - data = TextData - elif VideoClassificationAnnotation in types or VideoObjectAnnotation in types: - data = VideoData - elif DICOMObjectAnnotation in types: - data = DicomData - - if data_row.id: - return data(uid=data_row.id) - else: - return data(global_key=data_row.global_key) - - @staticmethod - def _get_consecutive_frames( - frames_indices: List[int]) -> List[Tuple[int, int]]: - consecutive = [] - for k, g in groupby(enumerate(frames_indices), lambda x: x[0] - x[1]): - group = list(map(itemgetter(1), g)) - consecutive.append((group[0], group[-1])) - return consecutive - - @classmethod - def _get_segment_frame_ranges( - cls, annotation_group: 
List[Union[VideoClassificationAnnotation, - VideoObjectAnnotation]] - ) -> List[Tuple[int, int]]: - sorted_frame_segment_indices = sorted([ - (annotation.frame, annotation.segment_index) - for annotation in annotation_group - if annotation.segment_index is not None - ]) - if len(sorted_frame_segment_indices) == 0: - # Group segment by consecutive frames, since `segment_index` is not present - return cls._get_consecutive_frames( - sorted([annotation.frame for annotation in annotation_group])) - elif len(sorted_frame_segment_indices) == len(annotation_group): - # Group segment by segment_index - last_segment_id = 0 - segment_groups = defaultdict(list) - for frame, segment_index in sorted_frame_segment_indices: - if segment_index < last_segment_id: - raise ValueError( - f"`segment_index` must be in ascending order. Please investigate video annotation at frame, '{frame}'" - ) - segment_groups[segment_index].append(frame) - last_segment_id = segment_index - frame_ranges = [] - for group in segment_groups.values(): - frame_ranges.append((group[0], group[-1])) - return frame_ranges - else: - raise ValueError( - f"Video annotations cannot partially have `segment_index` set") - - @classmethod - def _create_video_annotations( - cls, label: Label - ) -> Generator[Union[NDChecklistSubclass, NDRadioSubclass], None, None]: - - video_annotations = defaultdict(list) - for annot in label.annotations: - if isinstance( - annot, - (VideoClassificationAnnotation, VideoObjectAnnotation)): - video_annotations[annot.feature_schema_id or - annot.name].append(annot) - elif isinstance(annot, VideoMaskAnnotation): - yield NDObject.from_common(annotation=annot, data=label.data) - - for annotation_group in video_annotations.values(): - segment_frame_ranges = cls._get_segment_frame_ranges( - annotation_group) - if isinstance(annotation_group[0], VideoClassificationAnnotation): - annotation = annotation_group[0] - frames_data = [] - for frames in segment_frame_ranges: - frames_data.append({'start': frames[0], 'end': frames[-1]}) - annotation.extra.update({'frames': frames_data}) - yield NDClassification.from_common(annotation, label.data) - - elif isinstance(annotation_group[0], VideoObjectAnnotation): - segments = [] - for start_frame, end_frame in segment_frame_ranges: - segment = [] - for annotation in annotation_group: - if annotation.keyframe and start_frame <= annotation.frame <= end_frame: - segment.append(annotation) - segments.append(segment) - yield NDObject.from_common(segments, label.data) - - @classmethod - def _create_non_video_annotations(cls, label: Label): - non_video_annotations = [ - annot for annot in label.annotations - if not isinstance(annot, (VideoClassificationAnnotation, - VideoObjectAnnotation, - VideoMaskAnnotation)) - ] - for annotation in non_video_annotations: - if isinstance(annotation, ClassificationAnnotation): - if isinstance(annotation.value, Dropdown): - raise ValueError( - "Dropdowns are not supported by the NDJson format." - " Please filter out Dropdown annotations before converting." - ) - yield NDClassification.from_common(annotation, label.data) - elif isinstance(annotation, ObjectAnnotation): - yield NDObject.from_common(annotation, label.data) - elif isinstance(annotation, (ScalarMetric, ConfusionMatrixMetric)): - yield NDMetricAnnotation.from_common(annotation, label.data) - elif isinstance(annotation, RelationshipAnnotation): - yield NDRelationship.from_common(annotation, label.data) - else: - raise TypeError( - f"Unable to convert object to MAL format. 
`{type(getattr(annotation, 'value',annotation))}`" - ) - ----- -labelbox/data/serialization/ndjson/objects.py -from io import BytesIO -from typing import Any, Dict, List, Tuple, Union, Optional -import base64 - -from labelbox.data.annotation_types.ner.conversation_entity import ConversationEntity -from labelbox.data.annotation_types.video import VideoObjectAnnotation, DICOMObjectAnnotation -from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin, CustomMetric, CustomMetricsNotSupportedMixin -import numpy as np - -from labelbox import pydantic_compat -from PIL import Image -from labelbox.data.annotation_types import feature - -from labelbox.data.annotation_types.data.video import VideoData - -from ...annotation_types.data import ImageData, TextData, MaskData -from ...annotation_types.ner import DocumentEntity, DocumentTextSelection, TextEntity -from ...annotation_types.types import Cuid -from ...annotation_types.geometry import DocumentRectangle, Rectangle, Polygon, Line, Point, Mask -from ...annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation -from ...annotation_types.video import VideoMaskAnnotation, DICOMMaskAnnotation, MaskFrame, MaskInstance -from .classification import NDClassification, NDSubclassification, NDSubclassificationType -from .base import DataRow, NDAnnotation, NDJsonBase - - -class NDBaseObject(NDAnnotation): - classifications: List[NDSubclassificationType] = [] - - -class VideoSupported(pydantic_compat.BaseModel): - # support for video for objects are per-frame basis - frame: int - - -class DicomSupported(pydantic_compat.BaseModel): - group_key: str - - -class _Point(pydantic_compat.BaseModel): - x: float - y: float - - -class Bbox(pydantic_compat.BaseModel): - top: float - left: float - height: float - width: float - - -class NDPoint(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): - point: _Point - - def to_common(self) -> Point: - return Point(x=self.point.x, y=self.point.y) - - @classmethod - def from_common( - cls, - uuid: str, - point: Point, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None) -> "NDPoint": - return cls(point={ - 'x': point.x, - 'y': point.y - }, - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) - - -class NDFramePoint(VideoSupported): - point: _Point - classifications: List[NDSubclassificationType] = [] - - def to_common(self, name: str, feature_schema_id: Cuid, - segment_index: int) -> VideoObjectAnnotation: - return VideoObjectAnnotation(frame=self.frame, - segment_index=segment_index, - keyframe=True, - name=name, - feature_schema_id=feature_schema_id, - value=Point(x=self.point.x, - y=self.point.y), - classifications=[ - NDSubclassification.to_common(annot) - for annot in self.classifications - ]) - - @classmethod - def from_common( - cls, - frame: int, - point: Point, - classifications: List[NDSubclassificationType], - ): - return cls(frame=frame, - point=_Point(x=point.x, y=point.y), - classifications=classifications) - - -class NDLine(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): - line: List[_Point] - - def to_common(self) -> Line: - return Line(points=[Point(x=pt.x, y=pt.y) for pt in self.line]) - - @classmethod - def from_common( - cls, - uuid: 
str, - line: Line, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None) -> "NDLine": - return cls(line=[{ - 'x': pt.x, - 'y': pt.y - } for pt in line.points], - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) - - -class NDFrameLine(VideoSupported): - line: List[_Point] - classifications: List[NDSubclassificationType] = [] - - def to_common(self, name: str, feature_schema_id: Cuid, - segment_index: int) -> VideoObjectAnnotation: - return VideoObjectAnnotation( - frame=self.frame, - segment_index=segment_index, - keyframe=True, - name=name, - feature_schema_id=feature_schema_id, - value=Line(points=[Point(x=pt.x, y=pt.y) for pt in self.line]), - classifications=[ - NDSubclassification.to_common(annot) - for annot in self.classifications - ]) - - @classmethod - def from_common( - cls, - frame: int, - line: Line, - classifications: List[NDSubclassificationType], - ): - return cls(frame=frame, - line=[{ - 'x': pt.x, - 'y': pt.y - } for pt in line.points], - classifications=classifications) - - -class NDDicomLine(NDFrameLine): - - def to_common(self, name: str, feature_schema_id: Cuid, segment_index: int, - group_key: str) -> DICOMObjectAnnotation: - return DICOMObjectAnnotation( - frame=self.frame, - segment_index=segment_index, - keyframe=True, - name=name, - feature_schema_id=feature_schema_id, - value=Line(points=[Point(x=pt.x, y=pt.y) for pt in self.line]), - group_key=group_key) - - -class NDPolygon(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): - polygon: List[_Point] - - def to_common(self) -> Polygon: - return Polygon(points=[Point(x=pt.x, y=pt.y) for pt in self.polygon]) - - @classmethod - def from_common( - cls, - uuid: str, - polygon: Polygon, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None) -> "NDPolygon": - return cls(polygon=[{ - 'x': pt.x, - 'y': pt.y - } for pt in polygon.points], - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) - - -class NDRectangle(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): - bbox: Bbox - - def to_common(self) -> Rectangle: - return Rectangle(start=Point(x=self.bbox.left, y=self.bbox.top), - end=Point(x=self.bbox.left + self.bbox.width, - y=self.bbox.top + self.bbox.height)) - - @classmethod - def from_common( - cls, - uuid: str, - rectangle: Rectangle, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None - ) -> "NDRectangle": - return cls(bbox=Bbox(top=min(rectangle.start.y, rectangle.end.y), - left=min(rectangle.start.x, rectangle.end.x), - height=abs(rectangle.end.y - rectangle.start.y), - width=abs(rectangle.end.x - rectangle.start.x)), - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - 
classifications=classifications,
- page=extra.get('page'),
- unit=extra.get('unit'),
- confidence=confidence,
- custom_metrics=custom_metrics)
-
-
-class NDDocumentRectangle(NDRectangle):
- page: int
- unit: str
-
- def to_common(self) -> DocumentRectangle:
- return DocumentRectangle(start=Point(x=self.bbox.left, y=self.bbox.top),
- end=Point(x=self.bbox.left + self.bbox.width,
- y=self.bbox.top + self.bbox.height),
- page=self.page,
- unit=self.unit)
-
- @classmethod
- def from_common(
- cls,
- uuid: str,
- rectangle: DocumentRectangle,
- classifications: List[ClassificationAnnotation],
- name: str,
- feature_schema_id: Cuid,
- extra: Dict[str, Any],
- data: Union[ImageData, TextData],
- confidence: Optional[float] = None,
- custom_metrics: Optional[List[CustomMetric]] = None
- ) -> "NDDocumentRectangle":
- return cls(bbox=Bbox(top=min(rectangle.start.y, rectangle.end.y),
- left=min(rectangle.start.x, rectangle.end.x),
- height=abs(rectangle.end.y - rectangle.start.y),
- width=abs(rectangle.end.x - rectangle.start.x)),
- data_row=DataRow(id=data.uid, global_key=data.global_key),
- name=name,
- schema_id=feature_schema_id,
- uuid=uuid,
- classifications=classifications,
- page=rectangle.page,
- unit=rectangle.unit.value,
- confidence=confidence,
- custom_metrics=custom_metrics)
-
-
-class NDFrameRectangle(VideoSupported):
- bbox: Bbox
- classifications: List[NDSubclassificationType] = []
-
- def to_common(self, name: str, feature_schema_id: Cuid,
- segment_index: int) -> VideoObjectAnnotation:
- return VideoObjectAnnotation(
- frame=self.frame,
- segment_index=segment_index,
- keyframe=True,
- name=name,
- feature_schema_id=feature_schema_id,
- value=Rectangle(start=Point(x=self.bbox.left, y=self.bbox.top),
- end=Point(x=self.bbox.left + self.bbox.width,
- y=self.bbox.top + self.bbox.height)),
- classifications=[
- NDSubclassification.to_common(annot)
- for annot in self.classifications
- ])
-
- @classmethod
- def from_common(
- cls,
- frame: int,
- rectangle: Rectangle,
- classifications: List[NDSubclassificationType],
- ):
- return cls(frame=frame,
- bbox=Bbox(top=min(rectangle.start.y, rectangle.end.y),
- left=min(rectangle.start.x, rectangle.end.x),
- height=abs(rectangle.end.y - rectangle.start.y),
- width=abs(rectangle.end.x - rectangle.start.x)),
- classifications=classifications)
-
-
-class NDSegment(pydantic_compat.BaseModel):
- keyframes: List[Union[NDFrameRectangle, NDFramePoint, NDFrameLine]]
-
- @staticmethod
- def lookup_segment_object_type(segment: List) -> "NDFrameObjectType":
- """Determine which frame object type the segment's annotations contain and return it"""
- result = {
- Rectangle: NDFrameRectangle,
- Point: NDFramePoint,
- Line: NDFrameLine,
- }.get(type(segment[0].value))
- return result
-
- @staticmethod
- def segment_with_uuid(keyframe: Union[NDFrameRectangle, NDFramePoint,
- NDFrameLine], uuid: str):
- keyframe._uuid = uuid
- keyframe.extra = {'uuid': uuid}
- return keyframe
-
- def to_common(self, name: str, feature_schema_id: Cuid, uuid: str,
- segment_index: int):
- return [
- self.segment_with_uuid(
- keyframe.to_common(name=name,
- feature_schema_id=feature_schema_id,
- segment_index=segment_index), uuid)
- for keyframe in self.keyframes
- ]
-
- @classmethod
- def from_common(cls, segment):
- nd_frame_object_type = cls.lookup_segment_object_type(segment)
-
- return cls(keyframes=[
- nd_frame_object_type.from_common(
- object_annotation.frame, object_annotation.value, [
- NDSubclassification.from_common(annot)
- for annot in
object_annotation.classifications - ]) - for object_annotation in segment - ]) - - -class NDDicomSegment(NDSegment): - keyframes: List[NDDicomLine] - - @staticmethod - def lookup_segment_object_type(segment: List) -> "NDDicomObjectType": - """Used for determining which object type the annotation contains - returns the object type""" - segment_class = type(segment[0].value) - if segment_class == Line: - return NDDicomLine - else: - raise ValueError('DICOM segments only support Line objects') - - def to_common(self, name: str, feature_schema_id: Cuid, uuid: str, - segment_index: int, group_key: str): - return [ - self.segment_with_uuid( - keyframe.to_common(name=name, - feature_schema_id=feature_schema_id, - segment_index=segment_index, - group_key=group_key), uuid) - for keyframe in self.keyframes - ] - - -class NDSegments(NDBaseObject): - segments: List[NDSegment] - - def to_common(self, name: str, feature_schema_id: Cuid): - result = [] - for idx, segment in enumerate(self.segments): - result.extend( - segment.to_common(name=name, - feature_schema_id=feature_schema_id, - segment_index=idx, - uuid=self.uuid)) - return result - - @classmethod - def from_common(cls, segments: List[VideoObjectAnnotation], data: VideoData, - name: str, feature_schema_id: Cuid, - extra: Dict[str, Any]) -> "NDSegments": - - segments = [NDSegment.from_common(segment) for segment in segments] - - return cls(segments=segments, - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=extra.get('uuid')) - - -class NDDicomSegments(NDBaseObject, DicomSupported): - segments: List[NDDicomSegment] - - def to_common(self, name: str, feature_schema_id: Cuid): - result = [] - for idx, segment in enumerate(self.segments): - result.extend( - segment.to_common(name=name, - feature_schema_id=feature_schema_id, - segment_index=idx, - uuid=self.uuid, - group_key=self.group_key)) - return result - - @classmethod - def from_common(cls, segments: List[DICOMObjectAnnotation], data: VideoData, - name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - group_key: str) -> "NDDicomSegments": - - segments = [NDDicomSegment.from_common(segment) for segment in segments] - - return cls(segments=segments, - dataRow=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=extra.get('uuid'), - group_key=group_key) - - -class _URIMask(pydantic_compat.BaseModel): - instanceURI: str - colorRGB: Tuple[int, int, int] - - -class _PNGMask(pydantic_compat.BaseModel): - png: str - - -class NDMask(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): - mask: Union[_URIMask, _PNGMask] - - def to_common(self) -> Mask: - if isinstance(self.mask, _URIMask): - return Mask(mask=MaskData(url=self.mask.instanceURI), - color=self.mask.colorRGB) - else: - encoded_image_bytes = self.mask.png.encode('utf-8') - image_bytes = base64.b64decode(encoded_image_bytes) - image = np.array(Image.open(BytesIO(image_bytes))) - if np.max(image) > 1: - raise ValueError( - f"Expected binary mask. 
Found max value of {np.max(image)}") - # Color is 1,1,1 because it is a binary array and we are just stacking it into 3 channels - return Mask(mask=MaskData.from_2D_arr(image), color=(1, 1, 1)) - - @classmethod - def from_common( - cls, - uuid: str, - mask: Mask, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None) -> "NDMask": - - if mask.mask.url is not None: - lbv1_mask = _URIMask(instanceURI=mask.mask.url, colorRGB=mask.color) - else: - binary = np.all(mask.mask.value == mask.color, axis=-1) - im_bytes = BytesIO() - Image.fromarray(binary, 'L').save(im_bytes, format="PNG") - lbv1_mask = _PNGMask( - png=base64.b64encode(im_bytes.getvalue()).decode('utf-8')) - - return cls(mask=lbv1_mask, - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) - - -class NDVideoMasksFramesInstances(pydantic_compat.BaseModel): - frames: List[MaskFrame] - instances: List[MaskInstance] - - -class NDVideoMasks(NDJsonBase, ConfidenceMixin, CustomMetricsNotSupportedMixin): - masks: NDVideoMasksFramesInstances - - def to_common(self) -> VideoMaskAnnotation: - for mask_frame in self.masks.frames: - if mask_frame.im_bytes: - mask_frame.im_bytes = base64.b64decode( - mask_frame.im_bytes.encode('utf-8')) - - return VideoMaskAnnotation( - frames=self.masks.frames, - instances=self.masks.instances, - ) - - @classmethod - def from_common(cls, annotation, data): - for mask_frame in annotation.frames: - if mask_frame.im_bytes: - mask_frame.im_bytes = base64.b64encode( - mask_frame.im_bytes).decode('utf-8') - - return cls( - data_row=DataRow(id=data.uid, global_key=data.global_key), - masks=NDVideoMasksFramesInstances(frames=annotation.frames, - instances=annotation.instances), - ) - - -class NDDicomMasks(NDVideoMasks, DicomSupported): - - def to_common(self) -> DICOMMaskAnnotation: - return DICOMMaskAnnotation( - frames=self.masks.frames, - instances=self.masks.instances, - group_key=self.group_key, - ) - - @classmethod - def from_common(cls, annotation, data): - return cls( - data_row=DataRow(id=data.uid, global_key=data.global_key), - masks=NDVideoMasksFramesInstances(frames=annotation.frames, - instances=annotation.instances), - group_key=annotation.group_key.value, - ) - - -class Location(pydantic_compat.BaseModel): - start: int - end: int - - -class NDTextEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): - location: Location - - def to_common(self) -> TextEntity: - return TextEntity(start=self.location.start, end=self.location.end) - - @classmethod - def from_common( - cls, - uuid: str, - text_entity: TextEntity, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None - ) -> "NDTextEntity": - return cls(location=Location( - start=text_entity.start, - end=text_entity.end, - ), - data_row=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) - - -class NDDocumentEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): - name: str - 
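- # `text_selections` carries DocumentTextSelection spans that locate this
- # entity within the source document (PDF).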
text_selections: List[DocumentTextSelection] - - def to_common(self) -> DocumentEntity: - return DocumentEntity(name=self.name, - text_selections=self.text_selections) - - @classmethod - def from_common( - cls, - uuid: str, - document_entity: DocumentEntity, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None - ) -> "NDDocumentEntity": - - return cls(text_selections=document_entity.text_selections, - dataRow=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) - - -class NDConversationEntity(NDTextEntity): - message_id: str - - def to_common(self) -> ConversationEntity: - return ConversationEntity(start=self.location.start, - end=self.location.end, - message_id=self.message_id) - - @classmethod - def from_common( - cls, - uuid: str, - conversation_entity: ConversationEntity, - classifications: List[ClassificationAnnotation], - name: str, - feature_schema_id: Cuid, - extra: Dict[str, Any], - data: Union[ImageData, TextData], - confidence: Optional[float] = None, - custom_metrics: Optional[List[CustomMetric]] = None - ) -> "NDConversationEntity": - return cls(location=Location(start=conversation_entity.start, - end=conversation_entity.end), - message_id=conversation_entity.message_id, - dataRow=DataRow(id=data.uid, global_key=data.global_key), - name=name, - schema_id=feature_schema_id, - uuid=uuid, - classifications=classifications, - confidence=confidence, - custom_metrics=custom_metrics) - - -class NDObject: - - @staticmethod - def to_common(annotation: "NDObjectType") -> ObjectAnnotation: - common_annotation = annotation.to_common() - classifications = [ - NDSubclassification.to_common(annot) - for annot in annotation.classifications - ] - confidence = annotation.confidence if hasattr(annotation, - 'confidence') else None - - custom_metrics = annotation.custom_metrics if hasattr( - annotation, 'custom_metrics') else None - return ObjectAnnotation(value=common_annotation, - name=annotation.name, - feature_schema_id=annotation.schema_id, - classifications=classifications, - extra={ - 'uuid': annotation.uuid, - 'page': annotation.page, - 'unit': annotation.unit - }, - confidence=confidence, - custom_metrics=custom_metrics) - - @classmethod - def from_common( - cls, - annotation: Union[ObjectAnnotation, List[List[VideoObjectAnnotation]], - VideoMaskAnnotation], data: Union[ImageData, TextData] - ) -> Union[NDLine, NDPoint, NDPolygon, NDDocumentRectangle, NDRectangle, - NDMask, NDTextEntity]: - obj = cls.lookup_object(annotation) - - # if it is video segments - if (obj == NDSegments or obj == NDDicomSegments): - - first_video_annotation = annotation[0][0] - args = dict( - segments=annotation, - data=data, - name=first_video_annotation.name, - feature_schema_id=first_video_annotation.feature_schema_id, - extra=first_video_annotation.extra) - - if isinstance(first_video_annotation, DICOMObjectAnnotation): - group_key = first_video_annotation.group_key.value - args.update(dict(group_key=group_key)) - - return obj.from_common(**args) - elif (obj == NDVideoMasks or obj == NDDicomMasks): - return obj.from_common(annotation, data) - - subclasses = [ - NDSubclassification.from_common(annot) - for annot in annotation.classifications - ] - optional_kwargs = {} - if 
(annotation.confidence): - optional_kwargs['confidence'] = annotation.confidence - - if (annotation.custom_metrics): - optional_kwargs['custom_metrics'] = annotation.custom_metrics - - return obj.from_common(str(annotation._uuid), annotation.value, - subclasses, annotation.name, - annotation.feature_schema_id, annotation.extra, - data, **optional_kwargs) - - @staticmethod - def lookup_object( - annotation: Union[ObjectAnnotation, List]) -> "NDObjectType": - - if isinstance(annotation, DICOMMaskAnnotation): - result = NDDicomMasks - elif isinstance(annotation, VideoMaskAnnotation): - result = NDVideoMasks - elif isinstance(annotation, list): - try: - first_annotation = annotation[0][0] - except IndexError: - raise ValueError("Annotation list cannot be empty") - - if isinstance(first_annotation, DICOMObjectAnnotation): - result = NDDicomSegments - else: - result = NDSegments - else: - result = { - Line: NDLine, - Point: NDPoint, - Polygon: NDPolygon, - Rectangle: NDRectangle, - DocumentRectangle: NDDocumentRectangle, - Mask: NDMask, - TextEntity: NDTextEntity, - DocumentEntity: NDDocumentEntity, - ConversationEntity: NDConversationEntity, - }.get(type(annotation.value)) - if result is None: - raise TypeError( - f"Unable to convert object to MAL format. `{type(annotation.value)}`" - ) - return result - - -# NOTE: Deserialization of subclasses in pydantic is a known PIA, see here https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83 -# I could implement the registry approach suggested there, but I found that if I list subclass (that has more attributes) before the parent class, it works -# This is a bit of a hack, but it works for now -NDEntityType = Union[NDConversationEntity, NDTextEntity] -NDObjectType = Union[NDLine, NDPolygon, NDPoint, NDDocumentRectangle, - NDRectangle, NDMask, NDEntityType, NDDocumentEntity] - -NDFrameObjectType = NDFrameRectangle, NDFramePoint, NDFrameLine -NDDicomObjectType = NDDicomLine - ----- -labelbox/data/serialization/ndjson/base.py -from typing import Optional -from uuid import uuid4 - -from labelbox.utils import _CamelCaseMixin, is_exactly_one_set -from labelbox import pydantic_compat -from ...annotation_types.types import Cuid - - -class DataRow(_CamelCaseMixin): - id: str = None - global_key: str = None - - @pydantic_compat.root_validator() - def must_set_one(cls, values): - if not is_exactly_one_set(values.get('id'), values.get('global_key')): - raise ValueError("Must set either id or global_key") - return values - - -class NDJsonBase(_CamelCaseMixin): - uuid: str = None - data_row: DataRow - - @pydantic_compat.validator('uuid', pre=True, always=True) - def set_id(cls, v): - return v or str(uuid4()) - - def dict(self, *args, **kwargs): - """ Pop missing id or missing globalKey from dataRow """ - res = super().dict(*args, **kwargs) - if not self.data_row.id: - res['dataRow'].pop('id') - if not self.data_row.global_key: - res['dataRow'].pop('globalKey') - return res - - -class NDAnnotation(NDJsonBase): - name: Optional[str] = None - schema_id: Optional[Cuid] = None - message_id: Optional[str] = None - page: Optional[int] = None - unit: Optional[str] = None - - @pydantic_compat.root_validator() - def must_set_one(cls, values): - if ('schema_id' not in values or values['schema_id'] - is None) and ('name' not in values or values['name'] is None): - raise ValueError("Schema id or name are not set. 
Set either one.") - return values - - def dict(self, *args, **kwargs): - res = super().dict(*args, **kwargs) - if 'name' in res and res['name'] is None: - res.pop('name') - if 'schemaId' in res and res['schemaId'] is None: - res.pop('schemaId') - return res - ----- -labelbox/data/serialization/coco/annotation.py -from typing import Tuple, List, Union -from pathlib import Path -from collections import defaultdict - -from labelbox import pydantic_compat -import numpy as np - -from .path import PathSerializerMixin - - -def rle_decoding(rle_arr: List[int], w: int, h: int) -> np.ndarray: - indices = [] - for idx, cnt in zip(rle_arr[0::2], rle_arr[1::2]): - indices.extend(list(range(idx - 1, - idx + cnt - 1))) # RLE is 1-based index - mask = np.zeros(h * w, dtype=np.uint8) - mask[indices] = 1 - return mask.reshape((w, h)).T - - -def get_annotation_lookup(annotations): - annotation_lookup = defaultdict(list) - for annotation in annotations: - annotation_lookup[getattr(annotation, 'image_id', None) or - getattr(annotation, 'name')].append(annotation) - return annotation_lookup - - -class SegmentInfo(pydantic_compat.BaseModel): - id: int - category_id: int - area: int - bbox: Tuple[float, float, float, float] #[x,y,w,h], - iscrowd: int = 0 - - -class RLE(pydantic_compat.BaseModel): - counts: List[int] - size: Tuple[int, int] # h,w or w,h? - - -class COCOObjectAnnotation(pydantic_compat.BaseModel): - # All segmentations for a particular class in an image... - # So each image will have one of these for each class present in the image.. - # Annotations only exist if there is data.. - id: int - image_id: int - category_id: int - segmentation: Union[RLE, List[List[float]]] # [[x1,y1,x2,y2,x3,y3...]] - area: float - bbox: Tuple[float, float, float, float] #[x,y,w,h], - iscrowd: int = 0 - - -class PanopticAnnotation(PathSerializerMixin): - # One to one relationship between image and panoptic annotation - image_id: int - file_name: Path - segments_info: List[SegmentInfo] - ----- -labelbox/data/serialization/coco/panoptic_dataset.py -from concurrent.futures import ProcessPoolExecutor, as_completed -from typing import Dict, Any, List, Union -from pathlib import Path - -from labelbox import pydantic_compat -from tqdm import tqdm -import numpy as np -from PIL import Image - -from ...annotation_types.geometry import Polygon, Rectangle -from ...annotation_types import Label -from ...annotation_types.geometry.mask import Mask -from ...annotation_types.annotation import ObjectAnnotation -from ...annotation_types.data.raster import MaskData, ImageData -from ...annotation_types.collection import LabelCollection -from .categories import Categories, hash_category_name -from .image import CocoImage, get_image, get_image_id, id_to_rgb -from .annotation import PanopticAnnotation, SegmentInfo, get_annotation_lookup - - -def vector_to_coco_segment_info(canvas: np.ndarray, - annotation: ObjectAnnotation, - annotation_idx: int, image: CocoImage, - category_id: int): - - shapely = annotation.value.shapely - if shapely.is_empty: - return - - xmin, ymin, xmax, ymax = shapely.bounds - canvas = annotation.value.draw(height=image.height, - width=image.width, - canvas=canvas, - color=id_to_rgb(annotation_idx)) - - return SegmentInfo(id=annotation_idx, - category_id=category_id, - area=shapely.area, - bbox=[xmin, ymin, xmax - xmin, ymax - ymin]), canvas - - -def mask_to_coco_segment_info(canvas: np.ndarray, annotation, - annotation_idx: int, category_id): - color = id_to_rgb(annotation_idx) - mask = 
annotation.value.draw(color=color) - shapely = annotation.value.shapely - if shapely.is_empty: - return - - xmin, ymin, xmax, ymax = shapely.bounds - canvas = np.where(canvas == (0, 0, 0), mask, canvas) - return SegmentInfo(id=annotation_idx, - category_id=category_id, - area=shapely.area, - bbox=[xmin, ymin, xmax - xmin, ymax - ymin]), canvas - - -def process_label(label: Label, - idx: Union[int, str], - image_root, - mask_root, - all_stuff=False): - """ - Masks become stuff - Polygon and rectangle become thing - """ - annotations = get_annotation_lookup(label.annotations) - image_id = get_image_id(label, idx) - image = get_image(label, image_root, image_id) - canvas = np.zeros((image.height, image.width, 3)) - - segments = [] - categories = {} - is_thing = {} - - for class_idx, class_name in enumerate(annotations): - for annotation_idx, annotation in enumerate(annotations[class_name]): - categories[annotation.name] = hash_category_name(annotation.name) - if isinstance(annotation.value, Mask): - coco_segment_info = mask_to_coco_segment_info( - canvas, annotation, class_idx + 1, - categories[annotation.name]) - - if coco_segment_info is None: - # Filter out empty masks - continue - - segment, canvas = coco_segment_info - segments.append(segment) - is_thing[annotation.name] = 0 - - elif isinstance(annotation.value, (Polygon, Rectangle)): - coco_vector_info = vector_to_coco_segment_info( - canvas, - annotation, - annotation_idx=(class_idx if all_stuff else annotation_idx) - + 1, - image=image, - category_id=categories[annotation.name]) - - if coco_vector_info is None: - # Filter out empty annotations - continue - - segment, canvas = coco_vector_info - segments.append(segment) - is_thing[annotation.name] = 1 - int(all_stuff) - - mask_file = str(image.file_name).replace('.jpg', '.png') - mask_file = Path(mask_root, mask_file) - Image.fromarray(canvas.astype(np.uint8)).save(mask_file) - return image, PanopticAnnotation( - image_id=image_id, - file_name=Path(mask_file.name), - segments_info=segments), categories, is_thing - - -class CocoPanopticDataset(pydantic_compat.BaseModel): - info: Dict[str, Any] = {} - images: List[CocoImage] - annotations: List[PanopticAnnotation] - categories: List[Categories] - - @classmethod - def from_common(cls, - labels: LabelCollection, - image_root, - mask_root, - all_stuff, - max_workers=8): - all_coco_annotations = [] - coco_categories = {} - coco_things = {} - images = [] - - if max_workers: - with ProcessPoolExecutor(max_workers=max_workers) as exc: - futures = [ - exc.submit(process_label, label, idx, image_root, mask_root, - all_stuff) for idx, label in enumerate(labels) - ] - results = [ - future.result() for future in tqdm(as_completed(futures)) - ] - else: - results = [ - process_label(label, idx, image_root, mask_root, all_stuff) - for idx, label in enumerate(labels) - ] - - for result in results: - images.append(result[0]) - all_coco_annotations.append(result[1]) - coco_categories.update(result[2]) - coco_things.update(result[3]) - - category_mapping = { - category_id: idx + 1 - for idx, category_id in enumerate(coco_categories.values()) - } - categories = [ - Categories(id=category_mapping[idx], - name=name, - supercategory='all', - isthing=coco_things.get(name, 1)) - for name, idx in coco_categories.items() - ] - - for annot in all_coco_annotations: - for segment in annot.segments_info: - segment.category_id = category_mapping[segment.category_id] - - return CocoPanopticDataset(info={ - 'image_root': image_root, - 'mask_root': mask_root - }, - 
images=images,
- annotations=all_coco_annotations,
- categories=categories)
-
- def to_common(self, image_root: Path, mask_root: Path):
- category_lookup = {
- category.id: category for category in self.categories
- }
- annotation_lookup = {
- annotation.image_id: annotation for annotation in self.annotations
- }
- for image in self.images:
- annotations = []
- annotation = annotation_lookup[image.id]
-
- im_path = Path(image_root, image.file_name)
- if not im_path.exists():
- raise ValueError(
- f"Cannot find file {im_path}. Make sure `image_root` is set properly"
- )
- if not str(annotation.file_name).endswith('.png'):
- raise ValueError(
- f"COCO masks must be stored as png files and their extension must be `.png`. Found {annotation.file_name}"
- )
- mask = MaskData(
- file_path=str(Path(mask_root, annotation.file_name)))
-
- for segmentation in annotation.segments_info:
- category = category_lookup[segmentation.category_id]
- annotations.append(
- ObjectAnnotation(name=category.name,
- value=Mask(mask=mask,
- color=id_to_rgb(
- segmentation.id))))
- data = ImageData(file_path=str(im_path))
- yield Label(data=data, annotations=annotations)
- del annotation_lookup[image.id]
-
-----
-labelbox/data/serialization/coco/categories.py
-import sys
-from hashlib import md5
-
-from labelbox import pydantic_compat
-
-
-class Categories(pydantic_compat.BaseModel):
- id: int
- name: str
- supercategory: str
- isthing: int = 1
-
-
-def hash_category_name(name: str) -> int:
- return int.from_bytes(
- md5(name.encode('utf-8')).hexdigest().encode('utf-8'), 'little')
-
-----
-labelbox/data/serialization/coco/converter.py
-from typing import Dict, Any, Union
-from pathlib import Path
-import os
-
-from labelbox.data.annotation_types.collection import LabelCollection, LabelGenerator
-from labelbox.data.serialization.coco.instance_dataset import CocoInstanceDataset
-from labelbox.data.serialization.coco.panoptic_dataset import CocoPanopticDataset
-
-
-def create_path_if_not_exists(path: Union[Path, str],
- ignore_existing_data=False):
- path = Path(path)
- if not path.exists():
- path.mkdir(parents=True, exist_ok=True)
- elif not ignore_existing_data and os.listdir(path):
- raise ValueError(
- f"Directory `{path}` must be empty. Or set `ignore_existing_data=True`"
- )
- return path
-
-
-def validate_path(path: Union[Path, str], name: str):
- path = Path(path)
- if not path.exists():
- raise ValueError(f"{name} `{path}` must exist")
- return path
-
-
-class COCOConverter:
- """
- Class for converting between coco and labelbox formats.
- Note that this class is only compatible with image data.
-
- Subclasses are currently ignored.
- To use subclasses, manually flatten them before using the converter.
- """
-
- @staticmethod
- def serialize_instances(labels: LabelCollection,
- image_root: Union[Path, str],
- ignore_existing_data=False,
- max_workers=8) -> Dict[str, Any]:
- """
- Convert a Labelbox LabelCollection into an mscoco dataset.
- This function will only convert masks, polygons, and rectangles.
- Masks will be converted into individual instances.
- Use serialize_panoptic to prevent masks from being split apart.
-
- Args:
- labels: A collection of labels to convert
- image_root: Where to save images to
- ignore_existing_data: If True, do not raise an exception when the image directory already contains data.
- This exists only to support detectron2's panoptic FPN model which requires two mscoco payloads for the same images.
- max_workers: Number of workers to process dataset with. A value of 0 will process all data in the main process.
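- Example (a minimal sketch; assumes `labels` is a LabelCollection and the target directory is empty, since the images are written there):
- >>> coco = COCOConverter.serialize_instances(labels, image_root='/tmp/coco_images')
- >>> sorted(coco.keys())
- ['annotations', 'categories', 'images', 'info']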
- Returns:
- A dictionary containing labels in the coco object format.
- """
- image_root = create_path_if_not_exists(image_root, ignore_existing_data)
- return CocoInstanceDataset.from_common(labels=labels,
- image_root=image_root,
- max_workers=max_workers).dict()
-
- @staticmethod
- def serialize_panoptic(labels: LabelCollection,
- image_root: Union[Path, str],
- mask_root: Union[Path, str],
- all_stuff: bool = False,
- ignore_existing_data=False,
- max_workers: int = 8) -> Dict[str, Any]:
- """
- Convert a Labelbox LabelCollection into an mscoco dataset.
- This function will only convert masks, polygons, and rectangles.
- Masks are kept whole as panoptic segments rather than being split into individual instances.
-
- Args:
- labels: A collection of labels to convert
- image_root: Where to save images to
- mask_root: Where to save segmentation masks to
- all_stuff: If rectangle or polygon annotations are encountered, they will be treated as instances.
- To convert them to stuff class set `all_stuff=True`.
- ignore_existing_data: If True, do not raise an exception when the image directory already contains data.
- This exists only to support detectron2's panoptic FPN model which requires two mscoco payloads for the same images.
- max_workers: Number of workers to process dataset with. A value of 0 will process all data in the main process.
- Returns:
- A dictionary containing labels in the coco panoptic format.
- """
- image_root = create_path_if_not_exists(image_root, ignore_existing_data)
- mask_root = create_path_if_not_exists(mask_root, ignore_existing_data)
- return CocoPanopticDataset.from_common(labels=labels,
- image_root=image_root,
- mask_root=mask_root,
- all_stuff=all_stuff,
- max_workers=max_workers).dict()
-
- @staticmethod
- def deserialize_panoptic(json_data: Dict[str, Any],
- image_root: Union[Path, str],
- mask_root: Union[Path, str]) -> LabelGenerator:
- """
- Convert coco panoptic data into the labelbox format (as a LabelGenerator).
-
- Args:
- json_data: panoptic data as a dict
- image_root: Path to local images that are referenced by the panoptic json
- mask_root: Path to local segmentation masks that are referenced by the panoptic json
- Returns:
- LabelGenerator
- """
- image_root = validate_path(image_root, 'image_root')
- mask_root = validate_path(mask_root, 'mask_root')
- objs = CocoPanopticDataset(**json_data)
- gen = objs.to_common(image_root, mask_root)
- return LabelGenerator(data=gen)
-
- @staticmethod
- def deserialize_instances(json_data: Dict[str, Any],
- image_root: Path) -> LabelGenerator:
- """
- Convert coco object data into the labelbox format (as a LabelGenerator). 
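- Example (a hypothetical round trip; `instances.json` and `/tmp/coco_images` are assumed outputs of a prior `serialize_instances` call):
- >>> import json
- >>> with open('instances.json') as f:
- ...     coco_json = json.load(f)
- >>> labels = COCOConverter.deserialize_instances(coco_json, image_root='/tmp/coco_images')
- >>> for label in labels:
- ...     print(label.data.file_path)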
- - Args: - json_data: coco object data as a dict - image_root: Path to local images that are referenced by the coco object json - Returns: - LabelGenerator - """ - image_root = validate_path(image_root, 'image_root') - objs = CocoInstanceDataset(**json_data) - gen = objs.to_common(image_root) - return LabelGenerator(data=gen) - ----- -labelbox/data/serialization/coco/__init__.py -from .converter import COCOConverter - ----- -labelbox/data/serialization/coco/instance_dataset.py -# https://cocodataset.org/#format-data - -from concurrent.futures import ProcessPoolExecutor, as_completed -from typing import Any, Dict, List, Tuple, Optional -from pathlib import Path - -import numpy as np -from tqdm import tqdm -from labelbox import pydantic_compat - -from ...annotation_types import ImageData, MaskData, Mask, ObjectAnnotation, Label, Polygon, Point, Rectangle -from ...annotation_types.collection import LabelCollection -from .categories import Categories, hash_category_name -from .annotation import COCOObjectAnnotation, RLE, get_annotation_lookup, rle_decoding -from .image import CocoImage, get_image, get_image_id - - -def mask_to_coco_object_annotation( - annotation: ObjectAnnotation, annot_idx: int, image_id: int, - category_id: int) -> Optional[COCOObjectAnnotation]: - # This is going to fill any holes into the multipolygon - # If you need to support holes use the panoptic data format - shapely = annotation.value.shapely.simplify(1).buffer(0) - if shapely.is_empty: - return - - xmin, ymin, xmax, ymax = shapely.bounds - # Iterate over polygon once or multiple polygon for each item - area = shapely.area - - return COCOObjectAnnotation( - id=annot_idx, - image_id=image_id, - category_id=category_id, - segmentation=[ - np.array(s.exterior.coords).ravel().tolist() - for s in ([shapely] if shapely.type == "Polygon" else shapely.geoms) - ], - area=area, - bbox=[xmin, ymin, xmax - xmin, ymax - ymin], - iscrowd=0) - - -def vector_to_coco_object_annotation(annotation: ObjectAnnotation, - annot_idx: int, image_id: int, - category_id: int) -> COCOObjectAnnotation: - shapely = annotation.value.shapely - xmin, ymin, xmax, ymax = shapely.bounds - segmentation = [] - if isinstance(annotation.value, Polygon): - for point in annotation.value.points: - segmentation.extend([point.x, point.y]) - else: - box = annotation.value - segmentation.extend([ - box.start.x, box.start.y, box.end.x, box.start.y, box.end.x, - box.end.y, box.start.x, box.end.y - ]) - - return COCOObjectAnnotation(id=annot_idx, - image_id=image_id, - category_id=category_id, - segmentation=[segmentation], - area=shapely.area, - bbox=[xmin, ymin, xmax - xmin, ymax - ymin], - iscrowd=0) - - -def rle_to_common(class_annotations: COCOObjectAnnotation, - class_name: str) -> ObjectAnnotation: - mask = rle_decoding(class_annotations.segmentation.counts, - *class_annotations.segmentation.size[::-1]) - return ObjectAnnotation(name=class_name, - value=Mask(mask=MaskData.from_2D_arr(mask), - color=[1, 1, 1])) - - -def segmentations_to_common(class_annotations: COCOObjectAnnotation, - class_name: str) -> List[ObjectAnnotation]: - # Technically it is polygons. But the key in coco is called segmentations.. 
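- # A COCO polygon is a flat list [x1, y1, x2, y2, ...]; the loop below pairs
- # consecutive values into Points, e.g. [0, 0, 10, 0, 10, 10] becomes
- # Point(x=0, y=0), Point(x=10, y=0), Point(x=10, y=10).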
- annotations = [] - for points in class_annotations.segmentation: - annotations.append( - ObjectAnnotation(name=class_name, - value=Polygon(points=[ - Point(x=points[i], y=points[i + 1]) - for i in range(0, len(points), 2) - ]))) - return annotations - - -def object_annotation_to_coco( - annotation: ObjectAnnotation, annot_idx: int, image_id: int, - category_id: int) -> Optional[COCOObjectAnnotation]: - if isinstance(annotation.value, Mask): - return mask_to_coco_object_annotation(annotation, annot_idx, image_id, - category_id) - elif isinstance(annotation.value, (Polygon, Rectangle)): - return vector_to_coco_object_annotation(annotation, annot_idx, image_id, - category_id) - else: - return None - - -def process_label( - label: Label, - idx: int, - image_root: str, - max_annotations_per_image=10000 -) -> Tuple[np.ndarray, List[COCOObjectAnnotation], Dict[str, str]]: - annot_idx = idx * max_annotations_per_image - image_id = get_image_id(label, idx) - image = get_image(label, image_root, image_id) - coco_annotations = [] - annotation_lookup = get_annotation_lookup(label.annotations) - categories = {} - for class_name in annotation_lookup: - for annotation in annotation_lookup[class_name]: - category_id = categories.get(annotation.name) or hash_category_name( - annotation.name) - coco_annotation = object_annotation_to_coco(annotation, annot_idx, - image_id, category_id) - if coco_annotation is not None: - coco_annotations.append(coco_annotation) - if annotation.name not in categories: - categories[annotation.name] = category_id - annot_idx += 1 - - return image, coco_annotations, categories - - -class CocoInstanceDataset(pydantic_compat.BaseModel): - info: Dict[str, Any] = {} - images: List[CocoImage] - annotations: List[COCOObjectAnnotation] - categories: List[Categories] - - @classmethod - def from_common(cls, - labels: LabelCollection, - image_root: Path, - max_workers=8): - all_coco_annotations = [] - categories = {} - images = [] - futures = [] - coco_categories = {} - - if max_workers: - with ProcessPoolExecutor(max_workers=max_workers) as exc: - futures = [ - exc.submit(process_label, label, idx, image_root) - for idx, label in enumerate(labels) - ] - results = [ - future.result() for future in tqdm(as_completed(futures)) - ] - else: - - results = [ - process_label(label, idx, image_root) - for idx, label in enumerate(labels) - ] - - for result in results: - images.append(result[0]) - all_coco_annotations.extend(result[1]) - coco_categories.update(result[2]) - - category_mapping = { - category_id: idx + 1 - for idx, category_id in enumerate(coco_categories.values()) - } - categories = [ - Categories(id=category_mapping[idx], - name=name, - supercategory='all', - isthing=1) for name, idx in coco_categories.items() - ] - for annot in all_coco_annotations: - annot.category_id = category_mapping[annot.category_id] - - return CocoInstanceDataset(info={'image_root': image_root}, - images=images, - annotations=all_coco_annotations, - categories=categories) - - def to_common(self, image_root): - category_lookup = { - category.id: category for category in self.categories - } - annotation_lookup = get_annotation_lookup(self.annotations) - - for image in self.images: - im_path = Path(image_root, image.file_name) - if not im_path.exists(): - raise ValueError( - f"Cannot find file {im_path}. 
Make sure `image_root` is set properly" - ) - - data = ImageData(file_path=str(im_path)) - annotations = [] - for class_annotations in annotation_lookup[image.id]: - if isinstance(class_annotations.segmentation, RLE): - annotations.append( - rle_to_common( - class_annotations, category_lookup[ - class_annotations.category_id].name)) - elif isinstance(class_annotations.segmentation, list): - annotations.extend( - segmentations_to_common( - class_annotations, category_lookup[ - class_annotations.category_id].name)) - yield Label(data=data, annotations=annotations) - ----- -labelbox/data/serialization/coco/path.py -from labelbox import pydantic_compat -from pathlib import Path - - -class PathSerializerMixin(pydantic_compat.BaseModel): - - def dict(self, *args, **kwargs): - res = super().dict(*args, **kwargs) - return {k: str(v) if isinstance(v, Path) else v for k, v in res.items()} - ----- -labelbox/data/serialization/coco/image.py -from pathlib import Path - -from typing import Optional, Tuple -from PIL import Image -import imagesize - -from .path import PathSerializerMixin -from labelbox.data.annotation_types import Label - - -class CocoImage(PathSerializerMixin): - id: int - width: int - height: int - file_name: Path - license: Optional[int] = None - flickr_url: Optional[str] = None - coco_url: Optional[str] = None - - -def get_image_id(label: Label, idx: int) -> int: - if label.data.file_path is not None: - file_name = label.data.file_path.replace(".jpg", "") - if file_name.isdecimal(): - return file_name - return idx - - -def get_image(label: Label, image_root: Path, image_id: str) -> CocoImage: - path = Path(image_root, f"{image_id}.jpg") - if not path.exists(): - im = Image.fromarray(label.data.value) - im.save(path) - w, h = im.size - else: - w, h = imagesize.get(str(path)) - return CocoImage(id=image_id, width=w, height=h, file_name=Path(path.name)) - - -def id_to_rgb(id: int) -> Tuple[int, int, int]: - digits = [] - for _ in range(3): - digits.append(id % 256) - id //= 256 - return digits - - -def rgb_to_id(red: int, green: int, blue: int) -> int: - id = blue * 256 * 256 - id += (green * 256) - id += red - return id - ----- -labelbox/data/serialization/labelbox_v1/classification.py -from typing import List, Union - -from labelbox import pydantic_compat - -from .feature import LBV1Feature -from ...annotation_types.annotation import ClassificationAnnotation -from ...annotation_types.classification import Checklist, ClassificationAnswer, Radio, Text, Dropdown -from ...annotation_types.types import Cuid - - -class LBV1ClassificationAnswer(LBV1Feature): - - def to_common(self) -> ClassificationAnswer: - return ClassificationAnswer(feature_schema_id=self.schema_id, - name=self.title, - keyframe=self.keyframe, - extra={ - 'feature_id': self.feature_id, - 'value': self.value - }) - - @classmethod - def from_common( - cls, - answer: ClassificationAnnotation) -> "LBV1ClassificationAnswer": - return cls(schema_id=answer.feature_schema_id, - title=answer.name, - value=answer.extra.get('value'), - feature_id=answer.extra.get('feature_id'), - keyframe=answer.keyframe) - - -class LBV1Radio(LBV1Feature): - answer: LBV1ClassificationAnswer - - def to_common(self) -> Radio: - return Radio(answer=self.answer.to_common()) - - @classmethod - def from_common(cls, radio: Radio, feature_schema_id: Cuid, - **extra) -> "LBV1Radio": - return cls(schema_id=feature_schema_id, - answer=LBV1ClassificationAnswer.from_common(radio.answer), - **extra) - - -class LBV1Checklist(LBV1Feature): - answers: 
List[LBV1ClassificationAnswer]
-
- def to_common(self) -> Checklist:
- return Checklist(answer=[answer.to_common() for answer in self.answers])
-
- @classmethod
- def from_common(cls, checklist: Checklist, feature_schema_id: Cuid,
- **extra) -> "LBV1Checklist":
- return cls(schema_id=feature_schema_id,
- answers=[
- LBV1ClassificationAnswer.from_common(answer)
- for answer in checklist.answer
- ],
- **extra)
-
-
-class LBV1Dropdown(LBV1Feature):
- answer: List[LBV1ClassificationAnswer]
-
- def to_common(self) -> Dropdown:
- return Dropdown(answer=[answer.to_common() for answer in self.answer])
-
- @classmethod
- def from_common(cls, dropdown: Dropdown, feature_schema_id: Cuid,
- **extra) -> "LBV1Dropdown":
- return cls(schema_id=feature_schema_id,
- answer=[
- LBV1ClassificationAnswer.from_common(answer)
- for answer in dropdown.answer
- ],
- **extra)
-
-
-class LBV1Text(LBV1Feature):
- answer: str
-
- def to_common(self) -> Text:
- return Text(answer=self.answer)
-
- @classmethod
- def from_common(cls, text: Text, feature_schema_id: Cuid,
- **extra) -> "LBV1Text":
- return cls(schema_id=feature_schema_id, answer=text.answer, **extra)
-
-
-class LBV1Classifications(pydantic_compat.BaseModel):
- classifications: List[Union[LBV1Text, LBV1Radio, LBV1Dropdown,
- LBV1Checklist]] = []
-
- def to_common(self) -> List[ClassificationAnnotation]:
- classifications = [
- ClassificationAnnotation(value=classification.to_common(),
- name=classification.title,
- feature_schema_id=classification.schema_id,
- extra={
- 'value': classification.value,
- 'feature_id': classification.feature_id
- })
- for classification in self.classifications
- ]
- return classifications
-
- @classmethod
- def from_common(
- cls, annotations: List[ClassificationAnnotation]
- ) -> "LBV1Classifications":
- classifications = []
- for annotation in annotations:
- classification = cls.lookup_classification(annotation)
- if classification is not None:
- classifications.append(
- classification.from_common(annotation.value,
- annotation.feature_schema_id,
- **annotation.extra))
- else:
- raise TypeError(f"Unexpected type {type(annotation.value)}")
- return cls(classifications=classifications)
-
- @staticmethod
- def lookup_classification(
- annotation: ClassificationAnnotation
- ) -> Union[LBV1Text, LBV1Dropdown, LBV1Checklist, LBV1Radio]:
- return {
- Text: LBV1Text,
- Dropdown: LBV1Dropdown,
- Checklist: LBV1Checklist,
- Radio: LBV1Radio
- }.get(type(annotation.value))
-
-----
-labelbox/data/serialization/labelbox_v1/converter.py
-from labelbox.data.serialization.labelbox_v1.objects import LBV1Mask
-from typing import Any, Dict, Generator, Iterable, Union
-import logging
-
-from labelbox import parser
-import requests
-from copy import deepcopy
-from requests.exceptions import HTTPError
-from google.api_core import retry
-
-import labelbox
-from .label import LBV1Label
-from ...annotation_types.collection import (LabelCollection, LabelGenerator,
- PrefetchGenerator)
-
-logger = logging.getLogger(__name__)
-
-
-class LBV1Converter:
-
- @staticmethod
- def deserialize_video(json_data: Union[str, Iterable[Dict[str, Any]]],
- client: "labelbox.Client") -> LabelGenerator:
- """
- Converts a labelbox video export into the common labelbox format.
-
- Args:
- json_data: An iterable representing the labelbox video export.
- client: The labelbox client for downloading video annotations
- Returns:
- LabelGenerator containing the video data.
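- Example (a minimal sketch; assumes `client` is an authenticated labelbox.Client and `video_export` is the parsed video export):
- >>> labels = LBV1Converter.deserialize_video(video_export, client)
- >>> for label in labels:
- ...     print(label.uid)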
- """ - label_generator = (LBV1Label(**example).to_common() - for example in LBV1VideoIterator(json_data, client) - if example['Label']) - return LabelGenerator(data=label_generator) - - @staticmethod - def deserialize( - json_data: Union[str, Iterable[Dict[str, Any]]]) -> LabelGenerator: - """ - Converts a labelbox export (non-video) into the common labelbox format. - - Args: - json_data: An iterable representing the labelbox export. - Returns: - LabelGenerator containing the export data. - """ - - def label_generator(): - for example in json_data: - if 'frames' in example['Label']: - raise ValueError( - "Use `LBV1Converter.deserialize_video` to process video" - ) - - if example['Label']: - # Don't construct empty dict - yield LBV1Label(**example).to_common() - - return LabelGenerator(data=label_generator()) - - @staticmethod - def serialize( - labels: LabelCollection) -> Generator[Dict[str, Any], None, None]: - """ - Converts a labelbox common object to the labelbox json export format - - Note that any metric annotations will not be written since they are not defined in the LBV1 format. - - Args: - labels: Either a list of Label objects or a LabelGenerator (LabelCollection) - Returns: - A generator for accessing the labelbox json export representation of the data - """ - for label in labels: - res = LBV1Label.from_common(label) - yield res.dict(by_alias=True) - - -class LBV1VideoIterator(PrefetchGenerator): - """ - Generator that fetches video annotations in the background to be faster. - """ - - def __init__(self, examples, client): - self.client = client - super().__init__(examples) - - def _process(self, value): - value = deepcopy(value) - if 'frames' in value['Label']: - req = self._request(value) - value['Label'] = parser.loads(req) - return value - - @retry.Retry(predicate=retry.if_exception_type(HTTPError)) - def _request(self, value): - req = requests.get( - value['Label']['frames'], - headers={"Authorization": f"Bearer {self.client.api_key}"}) - if req.status_code == 401: - raise labelbox.exceptions.AuthenticationError("Invalid API key") - req.raise_for_status() - return req.text - ----- -labelbox/data/serialization/labelbox_v1/__init__.py -from .converter import LBV1Converter - ----- -labelbox/data/serialization/labelbox_v1/feature.py -from typing import Optional - -from labelbox import pydantic_compat - -from labelbox.utils import camel_case -from ...annotation_types.types import Cuid - - -class LBV1Feature(pydantic_compat.BaseModel): - keyframe: Optional[bool] = None - title: str = None - value: Optional[str] = None - schema_id: Optional[Cuid] = None - feature_id: Optional[Cuid] = None - - @pydantic_compat.root_validator - def check_ids(cls, values): - if values.get('value') is None: - values['value'] = values['title'] - return values - - def dict(self, *args, **kwargs): - res = super().dict(*args, **kwargs) - # This means these are no video frames .. 
- if self.keyframe is None: - res.pop('keyframe') - return res - - class Config: - allow_population_by_field_name = True - alias_generator = camel_case - ----- -labelbox/data/serialization/labelbox_v1/label.py -from labelbox.data.annotation_types.data.tiled_image import TiledImageData -from labelbox.utils import camel_case -from typing import List, Optional, Union, Dict, Any - -from labelbox import pydantic_compat - -from ...annotation_types.annotation import (ClassificationAnnotation, - ObjectAnnotation) -from ...annotation_types.video import VideoClassificationAnnotation, VideoObjectAnnotation -from ...annotation_types.data import ImageData, TextData, VideoData -from ...annotation_types.label import Label -from .classification import LBV1Classifications -from .objects import LBV1Objects, LBV1TextEntity - - -class LBV1LabelAnnotations(LBV1Classifications, LBV1Objects): - - def to_common( - self) -> List[Union[ObjectAnnotation, ClassificationAnnotation]]: - classifications = LBV1Classifications.to_common(self) - objects = LBV1Objects.to_common(self) - return [*objects, *classifications] - - @classmethod - def from_common( - cls, annotations: List[Union[ClassificationAnnotation, - ObjectAnnotation]] - ) -> "LBV1LabelAnnotations": - - objects = LBV1Objects.from_common( - [x for x in annotations if isinstance(x, ObjectAnnotation)]) - classifications = LBV1Classifications.from_common( - [x for x in annotations if isinstance(x, ClassificationAnnotation)]) - return cls(**objects.dict(), **classifications.dict()) - - -class LBV1LabelAnnotationsVideo(LBV1LabelAnnotations): - frame_number: int = pydantic_compat.Field(..., alias='frameNumber') - - def to_common( - self - ) -> List[Union[VideoClassificationAnnotation, VideoObjectAnnotation]]: - classifications = [ - VideoClassificationAnnotation( - value=classification.to_common(), - frame=self.frame_number, - name=classification.title, - feature_schema_id=classification.schema_id) - for classification in self.classifications - ] - - objects = [ - VideoObjectAnnotation(value=obj.to_common(), - keyframe=obj.keyframe, - classifications=[ - ClassificationAnnotation( - value=cls.to_common(), - feature_schema_id=cls.schema_id, - name=cls.title, - extra={ - 'feature_id': cls.feature_id, - 'title': cls.title, - 'value': cls.value, - }) for cls in obj.classifications - ], - name=obj.title, - frame=self.frame_number, - alternative_name=obj.value, - feature_schema_id=obj.schema_id, - extra={ - 'value': obj.value, - 'instanceURI': obj.instanceURI, - 'color': obj.color, - 'feature_id': obj.feature_id, - }) for obj in self.objects - ] - return [*classifications, *objects] - - @classmethod - def from_common( - cls, annotations: List[Union[VideoObjectAnnotation, - VideoClassificationAnnotation]] - ) -> "LBV1LabelAnnotationsVideo": - by_frames = {} - for annotation in annotations: - if annotation.frame in by_frames: - by_frames[annotation.frame].append(annotation) - else: - by_frames[annotation.frame] = [annotation] - - result = [] - for frame in by_frames: - converted = LBV1LabelAnnotations.from_common( - annotations=by_frames[frame]) - result.append( - LBV1LabelAnnotationsVideo( - frame_number=frame, - objects=converted.objects, - classifications=converted.classifications)) - - return result - - class Config: - allow_population_by_field_name = True - - -class Review(pydantic_compat.BaseModel): - score: int - id: str - created_at: str - created_by: str - label_id: Optional[str] = None - - class Config: - alias_generator = camel_case - - -Extra = lambda name: 
pydantic_compat.Field(None, alias=name, extra_field=True) - - -class LBV1Label(pydantic_compat.BaseModel): - label: Union[LBV1LabelAnnotations, - List[LBV1LabelAnnotationsVideo]] = pydantic_compat.Field( - ..., alias='Label') - data_row_id: str = pydantic_compat.Field(..., alias="DataRow ID") - row_data: str = pydantic_compat.Field(None, alias="Labeled Data") - id: Optional[str] = pydantic_compat.Field(None, alias='ID') - external_id: Optional[str] = pydantic_compat.Field(None, - alias="External ID") - data_row_media_attributes: Optional[Dict[str, Any]] = pydantic_compat.Field( - {}, alias="Media Attributes") - data_row_metadata: Optional[List[Dict[str, Any]]] = pydantic_compat.Field( - [], alias="DataRow Metadata") - - created_by: Optional[str] = Extra('Created By') - project_name: Optional[str] = Extra('Project Name') - created_at: Optional[str] = Extra('Created At') - updated_at: Optional[str] = Extra('Updated At') - seconds_to_label: Optional[float] = Extra('Seconds to Label') - agreement: Optional[float] = Extra('Agreement') - benchmark_agreement: Optional[float] = Extra('Benchmark Agreement') - benchmark_id: Optional[str] = Extra('Benchmark ID') - dataset_name: Optional[str] = Extra('Dataset Name') - reviews: Optional[List[Review]] = Extra('Reviews') - label_url: Optional[str] = Extra('View Label') - has_open_issues: Optional[float] = Extra('Has Open Issues') - skipped: Optional[bool] = Extra('Skipped') - media_type: Optional[str] = Extra('media_type') - data_split: Optional[str] = Extra('Data Split') - global_key: Optional[str] = Extra('Global Key') - - def to_common(self) -> Label: - if isinstance(self.label, list): - annotations = [] - for lbl in self.label: - annotations.extend(lbl.to_common()) - else: - annotations = self.label.to_common() - - return Label(data=self._data_row_to_common(), - uid=self.id, - annotations=annotations, - extra={ - field.alias: getattr(self, field_name) - for field_name, field in self.__fields__.items() - if field.field_info.extra.get('extra_field') - }) - - @classmethod - def from_common(cls, label: Label): - if isinstance(label.annotations[0], - (VideoObjectAnnotation, VideoClassificationAnnotation)): - label_ = LBV1LabelAnnotationsVideo.from_common(label.annotations) - else: - label_ = LBV1LabelAnnotations.from_common(label.annotations) - return LBV1Label(label=label_, - id=label.uid, - data_row_id=label.data.uid, - row_data=label.data.url, - external_id=label.data.external_id, - data_row_media_attributes=label.data.media_attributes, - data_row_metadata=label.data.metadata, - **label.extra) - - def _data_row_to_common( - self) -> Union[ImageData, TextData, VideoData, TiledImageData]: - # Use data row information to construct the appropriate annotation type - data_row_info = { - 'url' if self._is_url() else 'text': self.row_data, - 'external_id': self.external_id, - 'uid': self.data_row_id, - 'media_attributes': self.data_row_media_attributes, - 'metadata': self.data_row_metadata - } - - self.media_type = self.media_type or self._infer_media_type() - media_mapping = { - 'text': TextData, - 'image': ImageData, - 'video': VideoData - } - if self.media_type not in media_mapping: - raise ValueError( - f"Annotation types are only supported for {list(media_mapping)} media types." 
- f" Found {self.media_type}.") - return media_mapping[self.media_type](**data_row_info) - - def _infer_media_type(self) -> str: - # Determines the data row type based on the label content - if isinstance(self.label, list): - return 'video' - if self._has_text_annotations(): - return 'text' - elif self._has_object_annotations(): - return 'image' - else: - if self._row_contains((".jpg", ".png", ".jpeg")) and self._is_url(): - return 'image' - elif (self._row_contains((".txt", ".text", ".html")) and - self._is_url()) or not self._is_url(): - return 'text' - else: - # This condition will occur when a data row url does not contain a file extension - # and the label does not contain object annotations that indicate the media type. - # As a temporary workaround you can explicitly set the media_type - # in each label json payload before converting. - # We will eventually provide the media type in the export. - raise TypeError( - f"Can't infer data type from row data. row_data: {self.row_data[:200]}" - ) - - def _has_object_annotations(self) -> bool: - return len(self.label.objects) > 0 - - def _has_text_annotations(self) -> bool: - return len([ - annotation for annotation in self.label.objects - if isinstance(annotation, LBV1TextEntity) - ]) > 0 - - def _row_contains(self, substrs) -> bool: - lower_row_data = self.row_data.lower() - return any([substr in lower_row_data for substr in substrs]) - - def _is_url(self) -> bool: - return self.row_data.startswith( - ("http://", "https://", "gs://", - "s3://")) or "tileLayerUrl" in self.row_data - - class Config: - allow_population_by_field_name = True - ----- -labelbox/data/serialization/labelbox_v1/objects.py -from typing import Any, Dict, List, Optional, Union, Type -try: - from typing import Literal -except: - from typing_extensions import Literal - -from labelbox import pydantic_compat -import numpy as np - -from .classification import LBV1Checklist, LBV1Classifications, LBV1Radio, LBV1Text, LBV1Dropdown -from .feature import LBV1Feature -from ...annotation_types.annotation import (ClassificationAnnotation, - ObjectAnnotation) -from ...annotation_types.data import MaskData -from ...annotation_types.geometry import Line, Mask, Point, Polygon, Rectangle -from ...annotation_types.ner import TextEntity -from ...annotation_types.types import Cuid - - -class LBV1ObjectBase(LBV1Feature): - color: Optional[str] = None - instanceURI: Optional[str] = None - classifications: List[Union[LBV1Text, LBV1Radio, LBV1Dropdown, - LBV1Checklist]] = [] - page: Optional[int] = None - unit: Optional[str] = None - - def dict(self, *args, **kwargs) -> Dict[str, Any]: - res = super().dict(*args, **kwargs) - # This means these are not video frames .. - if self.instanceURI is None: - res.pop('instanceURI') - return res - - @pydantic_compat.validator('classifications', pre=True) - def validate_subclasses(cls, value, field): - # checklist subclasses create extra unessesary nesting. So we just remove it. - if isinstance(value, list) and len(value): - subclasses = [] - for v in value: - # this is due to Checklists providing extra brackets []. 
We grab every item - # in the brackets if this is the case - if isinstance(v, list): - for inner_v in v: - subclasses.append(inner_v) - else: - subclasses.append(v) - return subclasses - return value - - -class TIPointCoordinate(pydantic_compat.BaseModel): - coordinates: List[float] - - -class TILineCoordinate(pydantic_compat.BaseModel): - coordinates: List[List[float]] - - -class TIPolygonCoordinate(pydantic_compat.BaseModel): - coordinates: List[List[List[float]]] - - -class TIRectangleCoordinate(pydantic_compat.BaseModel): - coordinates: List[List[List[float]]] - - -class LBV1TIPoint(LBV1ObjectBase): - object_type: Literal['point'] = pydantic_compat.Field(..., alias='type') - geometry: TIPointCoordinate - - def to_common(self) -> Point: - lng, lat = self.geometry.coordinates - return Point(x=lng, y=lat) - - -class LBV1TILine(LBV1ObjectBase): - object_type: Literal['polyline'] = pydantic_compat.Field(..., alias='type') - geometry: TILineCoordinate - - def to_common(self) -> Line: - return Line(points=[ - Point(x=coord[0], y=coord[1]) for coord in self.geometry.coordinates - ]) - - -class LBV1TIPolygon(LBV1ObjectBase): - object_type: Literal['polygon'] = pydantic_compat.Field(..., alias='type') - geometry: TIPolygonCoordinate - - def to_common(self) -> Polygon: - for coord_list in self.geometry.coordinates: - return Polygon( - points=[Point(x=coord[0], y=coord[1]) for coord in coord_list]) - - -class LBV1TIRectangle(LBV1ObjectBase): - object_type: Literal['rectangle'] = pydantic_compat.Field(..., alias='type') - geometry: TIRectangleCoordinate - - def to_common(self) -> Rectangle: - coord_list = np.array(self.geometry.coordinates[0]) - - min_x, max_x = np.min(coord_list[:, 0]), np.max(coord_list[:, 0]) - min_y, max_y = np.min(coord_list[:, 1]), np.max(coord_list[:, 1]) - - start = [min_x, min_y] - end = [max_x, max_y] - - return Rectangle(start=Point(x=start[0], y=start[1]), - end=Point(x=end[0], y=end[1])) - - -class _Point(pydantic_compat.BaseModel): - x: float - y: float - - -class _Box(pydantic_compat.BaseModel): - top: float - left: float - height: float - width: float - - -class LBV1Rectangle(LBV1ObjectBase): - bbox: _Box - - def to_common(self) -> Rectangle: - return Rectangle(start=Point(x=self.bbox.left, y=self.bbox.top), - end=Point(x=self.bbox.left + self.bbox.width, - y=self.bbox.top + self.bbox.height)) - - @classmethod - def from_common(cls, rectangle: Rectangle, - classifications: List[ClassificationAnnotation], - feature_schema_id: Cuid, title: str, - extra: Dict[str, Any]) -> "LBV1Rectangle": - return cls(bbox=_Box( - top=rectangle.start.y, - left=rectangle.start.x, - height=rectangle.end.y - rectangle.start.y, - width=rectangle.end.x - rectangle.start.x, - ), - schema_id=feature_schema_id, - title=title, - classifications=classifications, - **extra) - - -class LBV1Polygon(LBV1ObjectBase): - polygon: List[_Point] - - def to_common(self) -> Polygon: - return Polygon(points=[Point(x=p.x, y=p.y) for p in self.polygon]) - - @classmethod - def from_common(cls, polygon: Polygon, - classifications: List[ClassificationAnnotation], - feature_schema_id: Cuid, title: str, - extra: Dict[str, Any]) -> "LBV1Polygon": - return cls( - polygon=[ - _Point(x=point.x, y=point.y) for point in polygon.points[:-1] - ], # drop closing point - classifications=classifications, - schema_id=feature_schema_id, - title=title, - **extra) - - -class LBV1Point(LBV1ObjectBase): - point: _Point - - def to_common(self) -> Point: - return Point(x=self.point.x, y=self.point.y) - - @classmethod - def 
from_common(cls, point: Point, - classifications: List[ClassificationAnnotation], - feature_schema_id: Cuid, title: str, - extra: Dict[str, Any]) -> "LBV1Point": - return cls(point=_Point(x=point.x, y=point.y), - classifications=classifications, - schema_id=feature_schema_id, - title=title, - **extra) - - -class LBV1Line(LBV1ObjectBase): - line: List[_Point] - - def to_common(self) -> Line: - return Line(points=[Point(x=p.x, y=p.y) for p in self.line]) - - @classmethod - def from_common(cls, polygon: Line, - classifications: List[ClassificationAnnotation], - feature_schema_id: Cuid, title: str, - extra: Dict[str, Any]) -> "LBV1Line": - return cls( - line=[_Point(x=point.x, y=point.y) for point in polygon.points], - classifications=classifications, - schema_id=feature_schema_id, - title=title, - **extra) - - -class LBV1Mask(LBV1ObjectBase): - instanceURI: str - - def to_common(self) -> Mask: - return Mask(mask=MaskData(url=self.instanceURI), color=(255, 255, 255)) - - @classmethod - def from_common(cls, mask: Mask, - classifications: List[ClassificationAnnotation], - feature_schema_id: Cuid, title: str, - extra: Dict[str, Any]) -> "LBV1Mask": - if mask.mask.url is None: - raise ValueError( - "Mask does not have a url. Use `LabelGenerator.add_url_to_masks`, or `Label.add_url_to_masks`." - ) - return cls(instanceURI=mask.mask.url, - classifications=classifications, - schema_id=feature_schema_id, - title=title, - **{ - k: v for k, v in extra.items() if k != 'instanceURI' - }) - - -class _TextPoint(pydantic_compat.BaseModel): - start: int - end: int - - -class _Location(pydantic_compat.BaseModel): - location: _TextPoint - - -class LBV1TextEntity(LBV1ObjectBase): - data: _Location - format: str = "text.location" - version: int = 1 - - def to_common(self) -> TextEntity: - return TextEntity( - start=self.data.location.start, - end=self.data.location.end, - ) - - @classmethod - def from_common(cls, text_entity: TextEntity, - classifications: List[ClassificationAnnotation], - feature_schema_id: Cuid, title: str, - extra: Dict[str, Any]) -> "LBV1TextEntity": - return cls(data=_Location( - location=_TextPoint(start=text_entity.start, end=text_entity.end)), - classifications=classifications, - schema_id=feature_schema_id, - title=title, - **extra) - - -class LBV1Objects(pydantic_compat.BaseModel): - objects: List[Union[ - LBV1Line, - LBV1Point, - LBV1Polygon, - LBV1Rectangle, - LBV1TextEntity, - LBV1Mask, - LBV1TIPoint, - LBV1TILine, - LBV1TIPolygon, - LBV1TIRectangle, - ]] - - def to_common(self) -> List[ObjectAnnotation]: - objects = [ - ObjectAnnotation(value=obj.to_common(), - classifications=[ - ClassificationAnnotation( - value=cls.to_common(), - feature_schema_id=cls.schema_id, - name=cls.title, - extra={ - 'feature_id': cls.feature_id, - 'title': cls.title, - 'value': cls.value - }) for cls in obj.classifications - ], - name=obj.title, - feature_schema_id=obj.schema_id, - extra={ - 'instanceURI': obj.instanceURI, - 'color': obj.color, - 'feature_id': obj.feature_id, - 'value': obj.value, - 'page': obj.page, - 'unit': obj.unit, - }) for obj in self.objects - ] - return objects - - @classmethod - def from_common(cls, annotations: List[ObjectAnnotation]) -> "LBV1Objects": - objects = [] - for annotation in annotations: - obj = cls.lookup_object(annotation) - subclasses = LBV1Classifications.from_common( - annotation.classifications).classifications - - objects.append( - obj.from_common( - annotation.value, subclasses, annotation.feature_schema_id, - annotation.name, { - 'keyframe': 
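- # Note: `keyframe` only exists on video object annotations;
- # plain ObjectAnnotations fall back to None here.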
getattr(annotation, 'keyframe', None), - **annotation.extra - })) - return cls(objects=objects) - - @staticmethod - def lookup_object( - annotation: ObjectAnnotation - ) -> Type[Union[LBV1Line, LBV1Point, LBV1Polygon, LBV1Rectangle, LBV1Mask, - LBV1TextEntity]]: - result = { - Line: LBV1Line, - Point: LBV1Point, - Polygon: LBV1Polygon, - Rectangle: LBV1Rectangle, - Mask: LBV1Mask, - TextEntity: LBV1TextEntity - }.get(type(annotation.value)) - if result is None: - raise TypeError(f"Unexpected type {type(annotation.value)}") - return result - ----- -labelbox/data/annotation_types/annotation.py -from typing import List, Union - -from labelbox.data.annotation_types.base_annotation import BaseAnnotation -from labelbox.data.annotation_types.geometry.geometry import Geometry - -from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin - -from labelbox.data.annotation_types.classification.classification import ClassificationAnnotation -from .ner import DocumentEntity, TextEntity, ConversationEntity - - -class ObjectAnnotation(BaseAnnotation, ConfidenceMixin, CustomMetricsMixin): - """Generic localized annotation (non classifications) - - >>> ObjectAnnotation( - >>> value=Rectangle( - >>> start=Point(x=0, y=0), - >>> end=Point(x=1, y=1) - >>> ), - >>> feature_schema_id="my-feature-schema-id" - >>> ) - - Args: - name (Optional[str]) - feature_schema_id (Optional[Cuid]) - value (Union[TextEntity, Geometry]): Localization of the annotation - classifications (Optional[List[ClassificationAnnotation]]): Optional sub classification of the annotation - extra (Dict[str, Any]) - """ - - value: Union[TextEntity, ConversationEntity, DocumentEntity, Geometry] - classifications: List[ClassificationAnnotation] = [] - ----- -labelbox/data/annotation_types/relationship.py -from labelbox import pydantic_compat -from enum import Enum -from labelbox.data.annotation_types.annotation import BaseAnnotation, ObjectAnnotation - - -class Relationship(pydantic_compat.BaseModel): - - class Type(Enum): - UNIDIRECTIONAL = "unidirectional" - BIDIRECTIONAL = "bidirectional" - - source: ObjectAnnotation - target: ObjectAnnotation - type: Type = Type.UNIDIRECTIONAL - - -class RelationshipAnnotation(BaseAnnotation): - value: Relationship - ----- -labelbox/data/annotation_types/__init__.py -from .geometry import Line -from .geometry import Point -from .geometry import Mask -from .geometry import Polygon -from .geometry import Rectangle -from .geometry import Geometry -from .geometry import DocumentRectangle -from .geometry import RectangleUnit - -from .annotation import ClassificationAnnotation -from .annotation import ObjectAnnotation - -from .relationship import RelationshipAnnotation -from .relationship import Relationship - -from .video import VideoClassificationAnnotation -from .video import VideoObjectAnnotation -from .video import DICOMObjectAnnotation -from .video import GroupKey -from .video import MaskFrame -from .video import MaskInstance -from .video import VideoMaskAnnotation -from .video import DICOMMaskAnnotation - -from .ner import ConversationEntity -from .ner import DocumentEntity -from .ner import DocumentTextSelection -from .ner import TextEntity - -from .classification import Checklist -from .classification import ClassificationAnswer -from .classification import Dropdown -from .classification import Radio -from .classification import Text - -from .data import AudioData -from .data import ConversationData -from .data import DicomData -from .data import DocumentData -from .data import HTMLData -from 
.data import ImageData -from .data import MaskData -from .data import TextData -from .data import VideoData -from .data import LlmPromptResponseCreationData -from .data import LlmPromptCreationData -from .data import LlmResponseCreationData - -from .label import Label -from .collection import LabelList -from .collection import LabelGenerator - -from .metrics import ScalarMetric -from .metrics import ScalarMetricAggregation -from .metrics import ConfusionMatrixMetric -from .metrics import ConfusionMatrixAggregation -from .metrics import ScalarMetricValue -from .metrics import ConfusionMatrixMetricValue - -from .data.tiled_image import EPSG -from .data.tiled_image import EPSGTransformer -from .data.tiled_image import TiledBounds -from .data.tiled_image import TiledImageData -from .data.tiled_image import TileLayer - ----- -labelbox/data/annotation_types/types.py -import sys -from typing import Generic, TypeVar, Any - -from typing_extensions import Annotated -from packaging import version -import numpy as np - -from labelbox import pydantic_compat - -Cuid = Annotated[str, pydantic_compat.Field(min_length=25, max_length=25)] - -DType = TypeVar('DType') -DShape = TypeVar('DShape') - - -class _TypedArray(np.ndarray, Generic[DType, DShape]): - - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, val, field: pydantic_compat.ModelField): - if not isinstance(val, np.ndarray): - raise TypeError(f"Expected numpy array. Found {type(val)}") - - if sys.version_info.minor > 6: - actual_dtype = field.sub_fields[-1].type_.__args__[0] - else: - actual_dtype = field.sub_fields[-1].type_.__values__[0] - - if val.dtype != actual_dtype: - raise TypeError( - f"Expected numpy array have type {actual_dtype}. Found {val.dtype}" - ) - return val - - -if version.parse(np.__version__) >= version.parse('1.25.0'): - from typing import GenericAlias - TypedArray = GenericAlias(_TypedArray, (Any, DType)) -elif version.parse(np.__version__) >= version.parse('1.23.0'): - from numpy._typing import _GenericAlias - TypedArray = _GenericAlias(_TypedArray, (Any, DType)) -elif version.parse('1.22.0') <= version.parse( - np.__version__) < version.parse('1.23.0'): - from numpy.typing import _GenericAlias - TypedArray = _GenericAlias(_TypedArray, (Any, DType)) -else: - TypedArray = _TypedArray[Any, DType] - ----- -labelbox/data/annotation_types/feature.py -from typing import Optional - -from labelbox import pydantic_compat - -from .types import Cuid - - -class FeatureSchema(pydantic_compat.BaseModel): - """ - Class that represents a feature schema. - Could be a annotation, a subclass, or an option. - Schema ids might not be known when constructing these objects so both a name and schema id are valid. 
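-
- Either identifier works on its own; a minimal sketch (the 25-character
- schema id below is a made-up placeholder):
-
- >>> FeatureSchema(name="dog")
- >>> FeatureSchema(feature_schema_id="ckabc1234567890abcdefghij")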
- """ - name: Optional[str] = None - feature_schema_id: Optional[Cuid] = None - - @pydantic_compat.root_validator - def must_set_one(cls, values): - if values['feature_schema_id'] is None and values['name'] is None: - raise ValueError( - "Must set either feature_schema_id or name for all feature schemas" - ) - return values - - def dict(self, *args, **kwargs): - res = super().dict(*args, **kwargs) - if 'name' in res and res['name'] is None: - res.pop('name') - if 'featureSchemaId' in res and res['featureSchemaId'] is None: - res.pop('featureSchemaId') - return res - ----- -labelbox/data/annotation_types/label.py -from collections import defaultdict -from typing import Any, Callable, Dict, List, Union, Optional -import warnings - -from labelbox import pydantic_compat - -import labelbox -from labelbox.data.annotation_types.data.tiled_image import TiledImageData -from labelbox.schema import ontology -from .annotation import ClassificationAnnotation, ObjectAnnotation -from .relationship import RelationshipAnnotation -from .classification import ClassificationAnswer -from .data import AudioData, ConversationData, DicomData, DocumentData, HTMLData, ImageData, MaskData, TextData, VideoData, LlmPromptCreationData, LlmPromptResponseCreationData, LlmResponseCreationData -from .geometry import Mask -from .metrics import ScalarMetric, ConfusionMatrixMetric -from .types import Cuid -from .video import VideoClassificationAnnotation -from .video import VideoObjectAnnotation, VideoMaskAnnotation -from ..ontology import get_feature_schema_lookup - -DataType = Union[VideoData, ImageData, TextData, TiledImageData, AudioData, - ConversationData, DicomData, DocumentData, HTMLData, - LlmPromptCreationData, LlmPromptResponseCreationData, - LlmResponseCreationData] - - -class Label(pydantic_compat.BaseModel): - """Container for holding data and annotations - - >>> Label( - >>> data = ImageData(url = "http://my-img.jpg"), - >>> annotations = [ - >>> ObjectAnnotation( - >>> value = Point(x = 10, y = 10), - >>> name = "target" - >>> ) - >>> ] - >>> ) - - Args: - uid: Optional Label Id in Labelbox - data: Data of Label, Image, Video, Text - annotations: List of Annotations in the label - extra: additional context - """ - uid: Optional[Cuid] = None - data: DataType - annotations: List[Union[ClassificationAnnotation, ObjectAnnotation, - VideoMaskAnnotation, ScalarMetric, - ConfusionMatrixMetric, - RelationshipAnnotation]] = [] - extra: Dict[str, Any] = {} - - def object_annotations(self) -> List[ObjectAnnotation]: - return self._get_annotations_by_type(ObjectAnnotation) - - def classification_annotations(self) -> List[ClassificationAnnotation]: - return self._get_annotations_by_type(ClassificationAnnotation) - - def _get_annotations_by_type(self, annotation_type): - return [ - annot for annot in self.annotations - if isinstance(annot, annotation_type) - ] - - def frame_annotations( - self - ) -> Dict[str, Union[VideoObjectAnnotation, VideoClassificationAnnotation]]: - frame_dict = defaultdict(list) - for annotation in self.annotations: - if isinstance( - annotation, - (VideoObjectAnnotation, VideoClassificationAnnotation)): - frame_dict[annotation.frame].append(annotation) - return frame_dict - - def add_url_to_data(self, signer) -> "Label": - """ - Creates signed urls for the data - Only uploads url if one doesn't already exist. - - Args: - signer: A function that accepts bytes and returns a signed url. 
- Returns: - Label with updated references to new data url - """ - self.data.create_url(signer) - return self - - def add_url_to_masks(self, signer) -> "Label": - """ - Creates signed urls for all masks in the Label. - Multiple masks can reference the same MaskData mask so this makes sure we only upload that url once. - Only uploads url if one doesn't already exist. - - Args: - signer: A function that accepts bytes and returns a signed url. - Returns: - Label with updated references to new mask url - """ - masks = [] - for annotation in self.annotations: - # Allows us to upload shared masks once - if isinstance(annotation.value, Mask): - in_list = False - for mask in masks: - if annotation.value.mask is mask: - in_list = True - if not in_list: - masks.append(annotation.value.mask) - for mask in masks: - mask.create_url(signer) - return self - - def create_data_row(self, dataset: "labelbox.Dataset", - signer: Callable[[bytes], str]) -> "Label": - """ - Creates a data row and adds to the given dataset. - Updates the label's data object to have the same external_id and uid as the data row. - - Args: - dataset: labelbox dataset object to add the new data row to - signer: A function that accepts bytes and returns a signed url. - Returns: - Label with updated references to new data row - """ - args = {'row_data': self.data.create_url(signer)} - if self.data.external_id is not None: - args.update({'external_id': self.data.external_id}) - - if self.data.uid is None: - data_row = dataset.create_data_row(**args) - self.data.uid = data_row.uid - self.data.external_id = data_row.external_id - return self - - def assign_feature_schema_ids( - self, ontology_builder: ontology.OntologyBuilder) -> "Label": - """ - Adds schema ids to all FeatureSchema objects in the Labels. - - Args: - ontology_builder: The ontology that matches the feature names assigned to objects in this dataset - Returns: - Label. useful for chaining these modifying functions - - Note: You can now import annotations using names directly without having to lookup schema_ids - """ - warnings.warn("This method is deprecated and will be " - "removed in a future release. Feature schema ids" - " are no longer required for importing.") - tool_lookup, classification_lookup = get_feature_schema_lookup( - ontology_builder) - for annotation in self.annotations: - if isinstance(annotation, ClassificationAnnotation): - self._assign_or_raise(annotation, classification_lookup) - self._assign_option(annotation, classification_lookup) - elif isinstance(annotation, ObjectAnnotation): - self._assign_or_raise(annotation, tool_lookup) - for classification in annotation.classifications: - self._assign_or_raise(classification, classification_lookup) - self._assign_option(classification, classification_lookup) - else: - raise TypeError( - f"Unexpected type found for annotation. {type(annotation)}") - return self - - def _assign_or_raise(self, annotation, lookup: Dict[str, str]) -> None: - if annotation.feature_schema_id is not None: - return - - feature_schema_id = lookup.get(annotation.name) - if feature_schema_id is None: - raise ValueError(f"No tool matches name {annotation.name}. 
" - f"Must be one of {list(lookup.keys())}.") - annotation.feature_schema_id = feature_schema_id - - def _assign_option(self, classification: ClassificationAnnotation, - lookup: Dict[str, str]) -> None: - if isinstance(classification.value.answer, str): - pass - elif isinstance(classification.value.answer, ClassificationAnswer): - self._assign_or_raise(classification.value.answer, lookup) - elif isinstance(classification.value.answer, list): - for answer in classification.value.answer: - self._assign_or_raise(answer, lookup) - else: - raise TypeError( - f"Unexpected type for answer found. {type(classification.value.answer)}" - ) - - @pydantic_compat.validator("annotations", pre=True) - def validate_union(cls, value): - supported = tuple([ - field.type_ - for field in cls.__fields__['annotations'].sub_fields[0].sub_fields - ]) - if not isinstance(value, list): - raise TypeError(f"Annotations must be a list. Found {type(value)}") - - for v in value: - if not isinstance(v, supported): - raise TypeError( - f"Annotations should be a list containing the following classes : {supported}. Found {type(v)}" - ) - return value - ----- -labelbox/data/annotation_types/collection.py -import logging -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Callable, Generator, Iterable, Union, Optional -from uuid import uuid4 -import warnings - -from tqdm import tqdm - -from labelbox.schema import ontology -from labelbox.orm.model import Entity -from ..ontology import get_classifications, get_tools -from ..generator import PrefetchGenerator -from .label import Label - -logger = logging.getLogger(__name__) - - -class LabelList: - """ - A container for interacting with a collection of labels. - Less memory efficient than LabelGenerator but more performant and convenient to use. - Use on smaller datasets. - """ - - def __init__(self, data: Optional[Iterable[Label]] = None): - warnings.warn("LabelList is deprecated and will be " - "removed in a future release.") - - if data is None: - self._data = [] - elif isinstance(data, Label): - self._data = [data] - else: - self._data = data - self._index = 0 - - def assign_feature_schema_ids( - self, ontology_builder: "ontology.OntologyBuilder") -> "LabelList": - """ - Adds schema ids to all FeatureSchema objects in the Labels. - - Args: - ontology_builder: The ontology that matches the feature names assigned to objects in this LabelList - Returns: - LabelList. useful for chaining these modifying functions - - Note: You can now import annotations using names directly without having to lookup schema_ids - """ - warnings.warn("This method is deprecated and will be " - "removed in a future release. Feature schema ids" - " are no longer required for importing.") - for label in self._data: - label.assign_feature_schema_ids(ontology_builder) - return self - - def add_to_dataset(self, - dataset: "Entity.Dataset", - signer: Callable[[bytes], str], - max_concurrency=20) -> "LabelList": - """ - Creates data rows from each labels data object and attaches the data to the given dataset. - Updates the label's data object to have the same external_id and uid as the data row. - It is reccomended to create a new dataset if memory is a concern because all dataset data rows are exported to make this faster. - Also note that this relies on exported data that it cached. - So this will not work on the same dataset more frequently than every 30 min. - The workaround is creating a new dataset each time this function is used. 
- - Args: - dataset: labelbox dataset object to add the new data row to - signer: A function that accepts bytes and returns a signed url. - Returns: - LabelList with updated references to new data rows - """ - self._ensure_unique_external_ids() - self.add_url_to_data(signer, max_concurrency=max_concurrency) - upload_task = dataset.create_data_rows([{ - 'row_data': label.data.url, - 'external_id': label.data.external_id - } for label in self._data]) - upload_task.wait_till_done() - - data_row_lookup = { - data_row.external_id: data_row.uid - for data_row in dataset.export_data_rows() - } - for label in self._data: - label.data.uid = data_row_lookup[label.data.external_id] - return self - - def add_url_to_masks(self, signer, max_concurrency=20) -> "LabelList": - """ - Creates signed urls for all masks in the LabelList. - Multiple masks objects can reference the same MaskData so this makes sure we only upload that url once. - Only uploads url if one doesn't already exist. - - Args: - signer: A function that accepts bytes and returns a signed url. - max_concurrency: how many threads to use for uploading. - Should be balanced to match the signing services capabilities. - Returns: - LabelList with updated references to the new mask urls - """ - for row in self._apply_threaded( - [label.add_url_to_masks for label in self._data], max_concurrency, - signer): - ... - return self - - def add_url_to_data(self, signer, max_concurrency=20) -> "LabelList": - """ - Creates signed urls for the data - Only uploads url if one doesn't already exist. - - Args: - signer: A function that accepts bytes and returns a signed url. - max_concurrency: how many threads to use for uploading. - Should be balanced to match the signing services capabilities. - Returns: - LabelList with updated references to the new data urls - """ - for row in self._apply_threaded( - [label.add_url_to_data for label in self._data], max_concurrency, - signer): - ... - return self - - def get_ontology(self) -> ontology.OntologyBuilder: - classifications = [] - tools = [] - for label in self._data: - tools = get_tools(label.object_annotations(), tools) - classifications = get_classifications( - label.classification_annotations(), classifications) - return ontology.OntologyBuilder(tools=tools, - classifications=classifications) - - def _ensure_unique_external_ids(self) -> None: - external_ids = set() - for label in self._data: - if label.data.external_id is None: - label.data.external_id = str(uuid4()) - else: - if label.data.external_id in external_ids: - raise ValueError( - f"External ids must be unique for bulk uploading. Found {label.data.external_id} more than once." - ) - external_ids.add(label.data.external_id) - - def append(self, label: Label) -> None: - self._data.append(label) - - def __iter__(self) -> "LabelList": - self._index = 0 - return self - - def __next__(self) -> Label: - if self._index == len(self._data): - self._index = 0 - raise StopIteration - - value = self._data[self._index] - self._index += 1 - return value - - def __len__(self) -> int: - return len(self._data) - - def __getitem__(self, idx: int) -> Label: - return self._data[idx] - - def _apply_threaded(self, fns, max_concurrency, *args): - futures = [] - with ThreadPoolExecutor(max_workers=max_concurrency) as executor: - for fn in fns: - futures.append(executor.submit(fn, *args)) - for future in tqdm(as_completed(futures)): - yield future.result() - - -class LabelGenerator(PrefetchGenerator): - """ - A container for interacting with a large collection of labels. 
- For a small number of labels, just use a list of Label objects.
- """
-
- def __init__(self, data: Generator[Label, None, None], *args, **kwargs):
- self._fns = {}
- super().__init__(data, *args, **kwargs)
-
- def as_list(self) -> "LabelList":
- warnings.warn("This method is deprecated and will be "
- "removed in a future release. LabelList"
- " class will be deprecated.")
- return LabelList(data=list(self))
-
- def assign_feature_schema_ids(
- self,
- ontology_builder: "ontology.OntologyBuilder") -> "LabelGenerator":
-
- def _assign_ids(label: Label):
- label.assign_feature_schema_ids(ontology_builder)
- return label
-
- warnings.warn("This method is deprecated and will be "
- "removed in a future release. Feature schema ids"
- " are no longer required for importing.")
- self._fns['assign_feature_schema_ids'] = _assign_ids
- return self
-
- def add_url_to_data(self, signer: Callable[[bytes],
- str]) -> "LabelGenerator":
- """
- Creates signed urls for the data.
- Only uploads a url if one doesn't already exist.
-
- Args:
- signer: A function that accepts bytes and returns a signed url.
- Returns:
- LabelGenerator that signs urls as data is accessed
- """
-
- def _add_url_to_data(label: Label):
- label.add_url_to_data(signer)
- return label
-
- self._fns['add_url_to_data'] = _add_url_to_data
- return self
-
- def add_to_dataset(self, dataset: "Entity.Dataset",
- signer: Callable[[bytes], str]) -> "LabelGenerator":
- """
- Creates data rows from each label's data object and attaches the data to the given dataset.
- Updates the label's data object to have the same external_id and uid as the data row.
-
- This is a lot slower than LabelList.add_to_dataset but also more memory efficient.
-
- Args:
- dataset: labelbox dataset object to add the new data row to
- signer: A function that accepts bytes and returns a signed url.
- Returns:
- LabelGenerator that updates references to the new data rows as data is accessed
- """
-
- def _add_to_dataset(label: Label):
- label.create_data_row(dataset, signer)
- return label
-
- self._fns['assign_datarow_ids'] = _add_to_dataset
- return self
-
- def add_url_to_masks(self, signer: Callable[[bytes],
- str]) -> "LabelGenerator":
- """
- Creates signed urls for all masks in the LabelGenerator.
- Multiple masks can reference the same MaskData so this makes sure we only upload that url once.
- Only uploads a url if one doesn't already exist.
-
- Args:
- signer: A function that accepts bytes and returns a signed url.
- Returns:
- LabelGenerator that updates references to the new mask urls as data is accessed
- """
-
- def _add_url_to_masks(label: Label):
- label.add_url_to_masks(signer)
- return label
-
- self._fns['add_url_to_masks'] = _add_url_to_masks
- return self
-
- def register_background_fn(self, fn: Callable[[Label], Label],
- name: str) -> "LabelGenerator":
- """
- Allows users to add arbitrary I/O functions to the generator.
- These functions will be executed in parallel and added to a prefetch queue.
-
- Args:
- fn: Callable that modifies a label and then returns the same label
- - For performance reasons, this function shouldn't run if the object already has the desired state.
- name: Register the name of the function. If the name already exists, then the function will be replaced.
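-
- A sketch of registering a custom function (illustrative; assumes
- `labels` is an existing LabelGenerator):
-
- >>> def log_label(label: Label) -> Label:
- >>> logger.info(label.uid)
- >>> return label
- >>> labels.register_background_fn(log_label, name="log_label")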
- """ - self._fns[name] = fn - return self - - def __iter__(self): - return self - - def _process(self, value): - for fn in self._fns.values(): - value = fn(value) - return value - - def __next__(self): - """ - Double checks that all values have been set. - Items could have been processed before any of these modifying functions are called. - None of these functions do anything if run more than once so the cost is minimal. - """ - value = super().__next__() - return self._process(value) - - -LabelCollection = Union[LabelGenerator, Iterable[Label]] - ----- -labelbox/data/annotation_types/base_annotation.py -import abc -from uuid import UUID, uuid4 -from typing import Any, Dict, Optional -from labelbox import pydantic_compat - -from .feature import FeatureSchema - - -class BaseAnnotation(FeatureSchema, abc.ABC): - """ Base annotation class. Shouldn't be directly instantiated - """ - _uuid: Optional[UUID] = pydantic_compat.PrivateAttr() - extra: Dict[str, Any] = {} - - def __init__(self, **data): - super().__init__(**data) - extra_uuid = data.get("extra", {}).get("uuid") - self._uuid = data.get("_uuid") or extra_uuid or uuid4() - ----- -labelbox/data/annotation_types/video.py -from enum import Enum -from typing import List, Optional, Tuple - -from labelbox import pydantic_compat -from labelbox.data.annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation - -from labelbox.data.annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation -from labelbox.data.annotation_types.feature import FeatureSchema -from labelbox.data.mixins import ConfidenceNotSupportedMixin, CustomMetricsNotSupportedMixin -from labelbox.utils import _CamelCaseMixin, is_valid_uri - - -class VideoClassificationAnnotation(ClassificationAnnotation): - """Video classification - Args: - name (Optional[str]) - feature_schema_id (Optional[Cuid]) - value (Union[Text, Checklist, Radio, Dropdown]) - frame (int): The frame index that this annotation corresponds to - segment_id (Optional[Int]): Index of video segment this annotation belongs to - extra (Dict[str, Any]) - """ - frame: int - segment_index: Optional[int] = None - - -class VideoObjectAnnotation(ObjectAnnotation, ConfidenceNotSupportedMixin, - CustomMetricsNotSupportedMixin): - """Video object annotation - >>> VideoObjectAnnotation( - >>> keyframe=True, - >>> frame=10, - >>> value=Rectangle( - >>> start=Point(x=0, y=0), - >>> end=Point(x=1, y=1) - >>> ), - >>> feature_schema_id="my-feature-schema-id" - >>> ) - Args: - name (Optional[str]) - feature_schema_id (Optional[Cuid]) - value (Geometry) - frame (Int): The frame index that this annotation corresponds to - keyframe (bool): Whether or not this annotation was a human generated or interpolated annotation - segment_id (Optional[Int]): Index of video segment this annotation belongs to - classifications (List[ClassificationAnnotation]) = [] - extra (Dict[str, Any]) - """ - frame: int - keyframe: bool - segment_index: Optional[int] = None - - -class GroupKey(Enum): - """Group key for DICOM annotations - """ - AXIAL = "axial" - SAGITTAL = "sagittal" - CORONAL = "coronal" - - -class DICOMObjectAnnotation(VideoObjectAnnotation): - """DICOM object annotation - >>> DICOMObjectAnnotation( - >>> name="dicom_polyline", - >>> frame=2, - >>> value=lb_types.Line(points = [ - >>> lb_types.Point(x=680, y=100), - >>> lb_types.Point(x=100, y=190), - >>> lb_types.Point(x=190, y=220) - >>> ]), - >>> segment_index=0, - >>> keyframe=True, - >>> Group_key=GroupKey.AXIAL - >>> ) - Args: - name (Optional[str]) 
- feature_schema_id (Optional[Cuid]) - value (Geometry) - group_key (GroupKey) - frame (Int): The frame index that this annotation corresponds to - keyframe (bool): Whether or not this annotation was a human generated or interpolated annotation - segment_id (Optional[Int]): Index of video segment this annotation belongs to - classifications (List[ClassificationAnnotation]) = [] - extra (Dict[str, Any]) - """ - group_key: GroupKey - - -class MaskFrame(_CamelCaseMixin, pydantic_compat.BaseModel): - index: int - instance_uri: Optional[str] = None - im_bytes: Optional[bytes] = None - - @pydantic_compat.root_validator() - def validate_args(cls, values): - im_bytes = values.get("im_bytes") - instance_uri = values.get("instance_uri") - - if im_bytes == instance_uri == None: - raise ValueError("One of `instance_uri`, `im_bytes` required.") - return values - - @pydantic_compat.validator("instance_uri") - def validate_uri(cls, v): - if not is_valid_uri(v): - raise ValueError(f"{v} is not a valid uri") - return v - - -class MaskInstance(_CamelCaseMixin, FeatureSchema): - color_rgb: Tuple[int, int, int] - name: str - - -class VideoMaskAnnotation(pydantic_compat.BaseModel): - """Video mask annotation - >>> VideoMaskAnnotation( - >>> frames=[ - >>> MaskFrame(index=1, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), - >>> MaskFrame(index=5, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), - >>> ], - >>> instances=[ - >>> MaskInstance(color_rgb=(0, 0, 255), name="mask1"), - >>> MaskInstance(color_rgb=(0, 255, 0), name="mask2"), - >>> MaskInstance(color_rgb=(255, 0, 0), name="mask3") - >>> ] - >>> ) - """ - frames: List[MaskFrame] - instances: List[MaskInstance] - - -class DICOMMaskAnnotation(VideoMaskAnnotation): - """DICOM mask annotation - >>> DICOMMaskAnnotation( - >>> name="dicom_mask", - >>> group_key=GroupKey.AXIAL, - >>> frames=[ - >>> MaskFrame(index=1, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), - >>> MaskFrame(index=5, instance_uri='https://storage.labelbox.com/cjhfn5y6s0pk507024nz1ocys1%2F1d60856c-59b7-3060-2754-83f7e93e0d01-1?Expires=1666901963361&KeyName=labelbox-assets-key-3&Signature=t-2s2DB4YjFuWEFak0wxYqfBfZA'), - >>> ], - >>> instances=[ - >>> MaskInstance(color_rgb=(0, 0, 255), name="mask1"), - >>> MaskInstance(color_rgb=(0, 255, 0), name="mask2"), - >>> MaskInstance(color_rgb=(255, 0, 0), name="mask3") - >>> ] - >>> ) - """ - group_key: GroupKey - ----- -labelbox/data/annotation_types/classification/classification.py -from typing import Any, Dict, List, Union, Optional -import warnings -from labelbox.data.annotation_types.base_annotation import BaseAnnotation - -from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin - -try: - from typing import Literal -except: - from typing_extensions import Literal - -from labelbox import pydantic_compat -from ..feature import FeatureSchema - - -# TODO: Replace when pydantic adds support for unions that don't coerce types -class _TempName(ConfidenceMixin, pydantic_compat.BaseModel): - name: str - - def dict(self, *args, **kwargs): - res = super().dict(*args, **kwargs) - 
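- # `name` exists only to discriminate the union during parsing
- # (see the TODO above), so it is stripped from the serialized output.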
res.pop('name') - return res - - -class ClassificationAnswer(FeatureSchema, ConfidenceMixin, CustomMetricsMixin): - """ - - Represents a classification option. - - Because it inherits from FeatureSchema - the option can be represented with either the name or feature_schema_id - - - The keyframe arg only applies to video classifications. - Each answer can have a keyframe independent of the others. - So unlike object annotations, classification annotations - track keyframes at a classification answer level. - """ - extra: Dict[str, Any] = {} - keyframe: Optional[bool] = None - classifications: List['ClassificationAnnotation'] = [] - - def dict(self, *args, **kwargs) -> Dict[str, str]: - res = super().dict(*args, **kwargs) - if res['keyframe'] is None: - res.pop('keyframe') - if res['classifications'] == []: - res.pop('classifications') - return res - - -class Radio(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel): - """ A classification with only one selected option allowed - - >>> Radio(answer = ClassificationAnswer(name = "dog")) - - """ - answer: ClassificationAnswer - - -class Checklist(_TempName): - """ A classification with many selected options allowed - - >>> Checklist(answer = [ClassificationAnswer(name = "cloudy")]) - - """ - name: Literal["checklist"] = "checklist" - answer: List[ClassificationAnswer] - - -class Text(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel): - """ Free form text - - >>> Text(answer = "some text answer") - - """ - answer: str - - -class Dropdown(_TempName): - """ - - A classification with many selected options allowed . - - This is not currently compatible with MAL. - - Deprecation Notice: Dropdown classification is deprecated and will be - removed in a future release. Dropdown will also - no longer be able to be created in the Editor on 3/31/2022. 
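-
- Example (illustrative answer name):
-
- >>> Dropdown(answer = [ClassificationAnswer(name = "first_answer")])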
- """ - name: Literal["dropdown"] = "dropdown" - answer: List[ClassificationAnswer] - - def __init__(self, **data: Any): - super().__init__(**data) - warnings.warn("Dropdown classification is deprecated and will be " - "removed in a future release") - - -class ClassificationAnnotation(BaseAnnotation, ConfidenceMixin, - CustomMetricsMixin): - """Classification annotations (non localized) - - >>> ClassificationAnnotation( - >>> value=Text(answer="my caption message"), - >>> feature_schema_id="my-feature-schema-id" - >>> ) - - Args: - name (Optional[str]) - classifications (Optional[List[ClassificationAnnotation]]): Optional sub classification of the annotation - feature_schema_id (Optional[Cuid]) - value (Union[Text, Checklist, Radio, Dropdown]) - extra (Dict[str, Any]) - """ - - value: Union[Text, Checklist, Radio, Dropdown] - message_id: Optional[str] = None - - -ClassificationAnswer.update_forward_refs() - ----- -labelbox/data/annotation_types/classification/__init__.py -from .classification import (Checklist, ClassificationAnswer, Dropdown, Radio, - Text) - ----- -labelbox/data/annotation_types/metrics/scalar.py -from typing import Dict, Optional, Union -from enum import Enum - -from .base import ConfidenceValue, BaseMetric - -from labelbox import pydantic_compat - -ScalarMetricValue = pydantic_compat.confloat(ge=0, le=100_000_000) -ScalarMetricConfidenceValue = Dict[ConfidenceValue, ScalarMetricValue] - - -class ScalarMetricAggregation(Enum): - ARITHMETIC_MEAN = "ARITHMETIC_MEAN" - GEOMETRIC_MEAN = "GEOMETRIC_MEAN" - HARMONIC_MEAN = "HARMONIC_MEAN" - SUM = "SUM" - - -RESERVED_METRIC_NAMES = ('true_positive_count', 'false_positive_count', - 'true_negative_count', 'false_negative_count', - 'precision', 'recall', 'f1', 'iou') - - -class ScalarMetric(BaseMetric): - """ Class representing scalar metrics - - For backwards compatibility, metric_name is optional. - The metric_name will be set to a default name in the editor if it is not set. - This is not recommended and support for empty metric_name fields will be removed. - aggregation will be ignored wihtout providing a metric name. - """ - metric_name: Optional[str] = None - value: Union[ScalarMetricValue, ScalarMetricConfidenceValue] - aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN - - @pydantic_compat.validator('metric_name') - def validate_metric_name(cls, name: Union[str, None]): - if name is None: - return None - clean_name = name.lower().strip() - if clean_name in RESERVED_METRIC_NAMES: - raise ValueError(f"`{clean_name}` is a reserved metric name. 
" - "Please provide another value for `metric_name`.") - return name - - def dict(self, *args, **kwargs): - res = super().dict(*args, **kwargs) - if res.get('metric_name') is None: - res.pop('aggregation') - return res - ----- -labelbox/data/annotation_types/metrics/__init__.py -from .scalar import ScalarMetric, ScalarMetricAggregation, ScalarMetricValue -from .confusion_matrix import ConfusionMatrixMetric, ConfusionMatrixAggregation, ConfusionMatrixMetricValue - ----- -labelbox/data/annotation_types/metrics/confusion_matrix.py -from enum import Enum -from typing import Tuple, Dict, Union - -from labelbox import pydantic_compat - -from .base import ConfidenceValue, BaseMetric - -Count = pydantic_compat.conint(ge=0, le=1e10) - -ConfusionMatrixMetricValue = Tuple[Count, Count, Count, Count] -ConfusionMatrixMetricConfidenceValue = Dict[ConfidenceValue, - ConfusionMatrixMetricValue] - - -class ConfusionMatrixAggregation(Enum): - CONFUSION_MATRIX = "CONFUSION_MATRIX" - - -class ConfusionMatrixMetric(BaseMetric): - """ Class representing confusion matrix metrics. - - In the editor, this provides precision, recall, and f-scores. - This should be used over multiple scalar metrics so that aggregations are accurate. - - Value should be a tuple representing: - [True Positive Count, False Positive Count, True Negative Count, False Negative Count] - - aggregation cannot be adjusted for confusion matrix metrics. - """ - metric_name: str - value: Union[ConfusionMatrixMetricValue, - ConfusionMatrixMetricConfidenceValue] - aggregation: ConfusionMatrixAggregation = pydantic_compat.Field( - ConfusionMatrixAggregation.CONFUSION_MATRIX, const=True) - ----- -labelbox/data/annotation_types/metrics/base.py -from abc import ABC -from typing import Dict, Optional, Any, Union - -from labelbox import pydantic_compat - -ConfidenceValue = pydantic_compat.confloat(ge=0, le=1) - -MIN_CONFIDENCE_SCORES = 2 -MAX_CONFIDENCE_SCORES = 15 - - -class BaseMetric(pydantic_compat.BaseModel, ABC): - value: Union[Any, Dict[float, Any]] - feature_name: Optional[str] = None - subclass_name: Optional[str] = None - extra: Dict[str, Any] = {} - - def dict(self, *args, **kwargs): - res = super().dict(*args, **kwargs) - return {k: v for k, v in res.items() if v is not None} - - @pydantic_compat.validator('value') - def validate_value(cls, value): - if isinstance(value, Dict): - if not (MIN_CONFIDENCE_SCORES <= len(value) <= - MAX_CONFIDENCE_SCORES): - raise pydantic_compat.ValidationError([ - pydantic_compat.ErrorWrapper(ValueError( - "Number of confidence scores must be greater" - f" than or equal to {MIN_CONFIDENCE_SCORES} and" - f" less than or equal to {MAX_CONFIDENCE_SCORES}. 
Found {len(value)}" - ), - loc='value') - ], cls) - return value - ----- -labelbox/data/annotation_types/ner/__init__.py -from .conversation_entity import ConversationEntity -from .document_entity import DocumentEntity, DocumentTextSelection -from .text_entity import TextEntity - ----- -labelbox/data/annotation_types/ner/conversation_entity.py -from labelbox.data.annotation_types.ner.text_entity import TextEntity -from labelbox.utils import _CamelCaseMixin - - -class ConversationEntity(TextEntity, _CamelCaseMixin): - """ Represents a text entity """ - message_id: str ----- -labelbox/data/annotation_types/ner/document_entity.py -from typing import List - -from labelbox import pydantic_compat -from labelbox.utils import _CamelCaseMixin - - -class DocumentTextSelection(_CamelCaseMixin, pydantic_compat.BaseModel): - token_ids: List[str] - group_id: str - page: int - - @pydantic_compat.validator("page") - def validate_page(cls, v): - if v < 1: - raise ValueError("Page must be greater than 1") - return v - - -class DocumentEntity(_CamelCaseMixin, pydantic_compat.BaseModel): - """ Represents a text entity """ - text_selections: List[DocumentTextSelection] - ----- -labelbox/data/annotation_types/ner/text_entity.py -from typing import Dict, Any - -from labelbox import pydantic_compat - - -class TextEntity(pydantic_compat.BaseModel): - """ Represents a text entity """ - start: int - end: int - extra: Dict[str, Any] = {} - - @pydantic_compat.root_validator - def validate_start_end(cls, values): - if 'start' in values and 'end' in values: - if (isinstance(values['start'], int) and - values['start'] > values['end']): - raise ValueError( - "Location end must be greater or equal to start") - return values - ----- -labelbox/data/annotation_types/geometry/mask.py -from typing import Callable, Optional, Tuple, Union, Dict, List - -import numpy as np -import cv2 - -from shapely.geometry import MultiPolygon, Polygon - -from ..data import MaskData -from .geometry import Geometry - -from labelbox import pydantic_compat - - -class Mask(Geometry): - """Mask used to represent a single class in a larger segmentation mask - - Example of a mutually exclusive class - - >>> arr = MaskData.from_2D_arr([ - >>> [0, 0, 0], - >>> [1, 1, 1], - >>> [2, 2, 2], - >>>]) - >>> annotations = [ - >>> ObjectAnnotation(value=Mask(mask=arr, color=1), name="dog"), - >>> ObjectAnnotation(value=Mask(mask=arr, color=2), name="cat"), - >>>] - - Args: - mask (MaskData): An object containing the actual mask, `MaskData` can - be shared across multiple `Masks` to more efficiently store data - for mutually exclusive segmentations. 
- color (Tuple[uint8, uint8, uint8]): RGB color or a single value - indicating the values of the class in the `MaskData` - """ - - mask: MaskData - color: Union[Tuple[int, int, int], int] - - @property - def geometry(self) -> Dict[str, Tuple[int, int, int]]: - mask = self.draw(color=1) - contours, hierarchy = cv2.findContours(image=mask, - mode=cv2.RETR_TREE, - method=cv2.CHAIN_APPROX_NONE) - - holes = [] - external_contours = [] - for i in range(len(contours)): - if hierarchy[0, i, 3] != -1: - #determined to be a hole based on contour hierarchy - holes.append(contours[i]) - else: - external_contours.append(contours[i]) - - external_polygons = self._extract_polygons_from_contours( - external_contours) - holes = self._extract_polygons_from_contours(holes) - - if not external_polygons.is_valid: - external_polygons = external_polygons.buffer(0) - - if not holes.is_valid: - holes = holes.buffer(0) - - return external_polygons.difference(holes).__geo_interface__ - - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Optional[Union[int, Tuple[int, int, int]]] = None, - thickness=None) -> np.ndarray: - """Converts the Mask object into a numpy array - - Args: - height (int): Optionally resize mask height before drawing. - width (int): Optionally resize mask width before drawing. - canvas (np.ndarray): Optionall provide a canvas to draw on - color (Union[int, Tuple[int,int,int]]): Color to draw the canvas. - Defaults to using the encoded color in the mask. - int will return the mask as a 1d array - tuple[int,int,int] will return the mask as a 3d array - thickness (None): Unused, exists for a consistent interface. - - Returns: - np.ndarray representing only this object - as opposed to the mask that this object references which might have multiple objects determined by colors - """ - mask = self.mask.value - mask = np.alltrue(mask == self.color, axis=2).astype(np.uint8) - - if height is not None or width is not None: - mask = cv2.resize(mask, - (width or mask.shape[1], height or mask.shape[0])) - - dims = [mask.shape[0], mask.shape[1]] - color = color or self.color - if isinstance(color, (tuple, list)): - dims = dims + [len(color)] - - canvas = canvas if canvas is not None else np.zeros(tuple(dims), - dtype=np.uint8) - canvas[mask.astype(bool)] = color - return canvas - - def _extract_polygons_from_contours(self, contours: List) -> MultiPolygon: - contours = map(np.squeeze, contours) - filtered_contours = filter(lambda contour: len(contour) > 2, contours) - polygons = map(Polygon, filtered_contours) - return MultiPolygon(polygons) - - def create_url(self, signer: Callable[[bytes], str]) -> str: - """ - Update the segmentation mask to have a url. - Only update the mask if it doesn't already have a url - - Args: - signer: A function that accepts bytes and returns a signed url. - Returns: - the url for the mask - """ - return self.mask.create_url(signer) - - @pydantic_compat.validator('color') - def is_valid_color(cls, color): - if isinstance(color, (tuple, list)): - if len(color) == 1: - color = [color[0]] * 3 - if len(color) != 3: - raise ValueError( - "Segmentation colors must be either a (r,g,b) tuple or a single grayscale value" - ) - elif not all([0 <= c <= 255 for c in color]): - raise ValueError( - f"All rgb colors must be between 0 and 255. Found : {color}" - ) - elif not (0 <= color <= 255): - raise ValueError( - f"All rgb colors must be between 0 and 255. 
Found : {color}") - - return color - ----- -labelbox/data/annotation_types/geometry/line.py -from typing import List, Optional, Union, Tuple - -import geojson -import numpy as np -import cv2 - -from shapely.geometry import LineString as SLineString - -from .point import Point -from .geometry import Geometry - -from labelbox import pydantic_compat - - -class Line(Geometry): - """Line annotation - - Args: - points (List[Point]): A list of `Point` geometries - - >>> Line(points = [Point(x=3,y=4), Point(x=3,y=5)]) - - """ - points: List[Point] - - @property - def geometry(self) -> geojson.MultiLineString: - return geojson.MultiLineString( - [[[point.x, point.y] for point in self.points]]) - - @classmethod - def from_shapely(cls, shapely_obj: SLineString) -> "Line": - """Transforms a shapely object.""" - if not isinstance(shapely_obj, SLineString): - raise TypeError( - f"Expected Shapely Line. Got {shapely_obj.geom_type}") - - obj_coords = shapely_obj.__geo_interface__['coordinates'] - return Line( - points=[Point(x=coords[0], y=coords[1]) for coords in obj_coords]) - - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Union[int, Tuple[int, int, int]] = (255, 255, 255), - thickness: int = 1) -> np.ndarray: - """ - Draw the line onto a 3d mask - Args: - height (int): height of the mask - width (int): width of the mask - thickness (int): How thick to draw the line - color (int): color for the line. - RGB values by default but if a 2D canvas is provided this can set this to an int. - canvas (np.ndarry): Canvas for drawing line on. - Returns: - numpy array representing the mask with the line drawn on it. - """ - canvas = self.get_or_create_canvas(height, width, canvas) - pts = np.array(self.geometry['coordinates']).astype(np.int32) - return cv2.polylines(canvas, - pts, - False, - color=color, - thickness=thickness) - - @pydantic_compat.validator('points') - def is_geom_valid(cls, points): - if len(points) < 2: - raise ValueError( - f"A line must have at least 2 points to be valid. Found {points}" - ) - - return points - ----- -labelbox/data/annotation_types/geometry/polygon.py -from typing import List, Optional, Union, Tuple - -import cv2 -import geojson -import numpy as np - -from shapely.geometry import Polygon as SPolygon - -from .geometry import Geometry -from .point import Point - -from labelbox import pydantic_compat - - -class Polygon(Geometry): - """Polygon geometry - - A polygon is created from a collection of points - - >>> Polygon(points=[Point(x=0, y=0), Point(x=1, y=0), Point(x=1, y=1), Point(x=0, y=0)]) - - Args: - points (List[Point]): List of `Points`, minimum of three points. If you do not - close the polygon (the last point and first point are the same) an additional - point is added to close it. - - """ - points: List[Point] - - @property - def geometry(self) -> geojson.Polygon: - if self.points[0] != self.points[-1]: - self.points.append(self.points[0]) - return geojson.Polygon([[(point.x, point.y) for point in self.points]]) - - @classmethod - def from_shapely(cls, shapely_obj: SPolygon) -> "Polygon": - """Transforms a shapely object.""" - #we only consider 0th index because we only allow for filled polygons - if not isinstance(shapely_obj, SPolygon): - raise TypeError( - f"Expected Shapely Polygon. 
Got {shapely_obj.geom_type}") - obj_coords = shapely_obj.__geo_interface__['coordinates'][0] - return Polygon( - points=[Point(x=coords[0], y=coords[1]) for coords in obj_coords]) - - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Union[int, Tuple[int, int, int]] = (255, 255, 255), - thickness: int = -1) -> np.ndarray: - """ - Draw the polygon onto a 3d mask - Args: - height (int): height of the mask - width (int): width of the mask - color (int): color for the polygon. - RGB values by default but if a 2D canvas is provided this can set this to an int. - thickness (int): How thick to make the polygon border. -1 fills in the polygon - canvas (np.ndarray): Canvas to draw the polygon on - Returns: - numpy array representing the mask with the polygon drawn on it. - """ - canvas = self.get_or_create_canvas(height, width, canvas) - pts = np.array(self.geometry['coordinates']).astype(np.int32) - if thickness == -1: - return cv2.fillPoly(canvas, pts, color) - return cv2.polylines(canvas, pts, True, color, thickness) - - @pydantic_compat.validator('points') - def is_geom_valid(cls, points): - if len(points) < 3: - raise ValueError( - f"A polygon must have at least 3 points to be valid. Found {points}" - ) - if points[0] != points[-1]: - points.append(points[0]) - - return points - ----- -labelbox/data/annotation_types/geometry/__init__.py -from .line import Line -from .point import Point -from .mask import Mask -from .polygon import Polygon -from .rectangle import Rectangle -from .rectangle import DocumentRectangle -from .rectangle import RectangleUnit -from .geometry import Geometry - ----- -labelbox/data/annotation_types/geometry/point.py -from typing import Optional, Tuple, Union - -import geojson -import numpy as np -import cv2 -from shapely.geometry import Point as SPoint - -from .geometry import Geometry - - -class Point(Geometry): - """Point geometry - - >>> Point(x=0, y=0) - - Args: - x (float) - y (float) - - """ - x: float - y: float - - @property - def geometry(self) -> geojson.Point: - return geojson.Point((self.x, self.y)) - - @classmethod - def from_shapely(cls, shapely_obj: SPoint) -> "Point": - """Transforms a shapely object.""" - if not isinstance(shapely_obj, SPoint): - raise TypeError( - f"Expected Shapely Point. Got {shapely_obj.geom_type}") - - obj_coords = shapely_obj.__geo_interface__['coordinates'] - return Point(x=obj_coords[0], y=obj_coords[1]) - - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Union[int, Tuple[int, int, int]] = (255, 255, 255), - thickness: int = 10) -> np.ndarray: - """ - Draw the point onto a 3d mask - Args: - height (int): height of the mask - width (int): width of the mask - thickness (int): pixel radius of the point - color (int): color for the point. - RGB values by default but if a 2D canvas is provided this can set this to an int. - canvas (np.ndarray): Canvas to draw the point on - Returns: - numpy array representing the mask with the point drawn on it. 
- """ - canvas = self.get_or_create_canvas(height, width, canvas) - return cv2.circle(canvas, (int(self.x), int(self.y)), - radius=thickness, - color=color, - thickness=-1) - ----- -labelbox/data/annotation_types/geometry/geometry.py -from typing import Dict, Any, Optional, Union, Tuple -from abc import ABC, abstractmethod - -import geojson -import numpy as np -from labelbox import pydantic_compat - -from shapely import geometry as geom - - -class Geometry(pydantic_compat.BaseModel, ABC): - """Abstract base class for geometry objects - """ - extra: Dict[str, Any] = {} - - @property - def shapely( - self - ) -> Union[geom.Point, geom.LineString, geom.Polygon, geom.MultiPoint, - geom.MultiLineString, geom.MultiPolygon]: - return geom.shape(self.geometry) - - def get_or_create_canvas(self, height: Optional[int], width: Optional[int], - canvas: Optional[np.ndarray]) -> np.ndarray: - if canvas is None: - if height is None or width is None: - raise ValueError( - "Must either provide canvas or height and width") - canvas = np.zeros((height, width, 3), dtype=np.uint8) - canvas = np.ascontiguousarray(canvas) - return canvas - - @property - @abstractmethod - def geometry(self) -> geojson: - pass - - @abstractmethod - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Optional[Union[int, Tuple[int, int, int]]] = None, - thickness: Optional[int] = 1) -> np.ndarray: - pass - ----- -labelbox/data/annotation_types/geometry/rectangle.py -from typing import Optional, Union, Tuple -from enum import Enum - -import cv2 -import geojson -import numpy as np - -from shapely.geometry import Polygon as SPolygon - -from .geometry import Geometry -from .point import Point - - -class Rectangle(Geometry): - """Represents a 2d rectangle. Also known as a bounding box - - >>> Rectangle(start=Point(x=0, y=0), end=Point(x=1, y=1)) - - Args: - start (Point): Top left coordinate of the rectangle - end (Point): Bottom right coordinate of the rectangle - """ - start: Point - end: Point - - @property - def geometry(self) -> geojson.geometry.Geometry: - return geojson.Polygon([[ - [self.start.x, self.start.y], - [self.start.x, self.end.y], - [self.end.x, self.end.y], - [self.end.x, self.start.y], - [self.start.x, self.start.y], - ]]) - - @classmethod - def from_shapely(cls, shapely_obj: SPolygon) -> "Rectangle": - """Transforms a shapely object. - - If the provided shape is a non-rectangular polygon, a rectangle will be - returned based on the min and max x,y values.""" - if not isinstance(shapely_obj, SPolygon): - raise TypeError( - f"Expected Shapely Polygon. Got {shapely_obj.geom_type}") - - min_x, min_y, max_x, max_y = shapely_obj.bounds - - start = [min_x, min_y] - end = [max_x, max_y] - - return Rectangle(start=Point(x=start[0], y=start[1]), - end=Point(x=end[0], y=end[1])) - - def draw(self, - height: Optional[int] = None, - width: Optional[int] = None, - canvas: Optional[np.ndarray] = None, - color: Union[int, Tuple[int, int, int]] = (255, 255, 255), - thickness: int = -1) -> np.ndarray: - """ - Draw the rectangle onto a 3d mask - Args: - height (int): height of the mask - width (int): width of the mask - color (int): color for the polygon. - RGB values by default but if a 2D canvas is provided this can set this to an int. - thickness (int): How thick to make the rectangle border. -1 fills in the rectangle - canvas (np.ndarray): Canvas to draw rectangle on - Returns: - numpy array representing the mask with the rectangle drawn on it. 
- """ - canvas = self.get_or_create_canvas(height, width, canvas) - pts = np.array(self.geometry['coordinates']).astype(np.int32) - if thickness == -1: - return cv2.fillPoly(canvas, pts, color) - return cv2.polylines(canvas, pts, True, color, thickness) - - @classmethod - def from_xyhw(cls, x: float, y: float, h: float, w: float) -> "Rectangle": - """Create Rectangle from x,y, height width format""" - return cls(start=Point(x=x, y=y), end=Point(x=x + w, y=y + h)) - - -class RectangleUnit(Enum): - INCHES = 'INCHES' - PIXELS = 'PIXELS' - POINTS = 'POINTS' - - -class DocumentRectangle(Rectangle): - """Represents a 2d rectangle on a Document - - >>> Rectangle( - >>> start=Point(x=0, y=0), - >>> end=Point(x=1, y=1), - >>> page=4, - >>> unit=RectangleUnits.POINTS - >>> ) - - Args: - start (Point): Top left coordinate of the rectangle - end (Point): Bottom right coordinate of the rectangle - page (int): Page number of the document - unit (RectangleUnits): Units of the rectangle - """ - page: int - unit: RectangleUnit - ----- -labelbox/data/annotation_types/data/conversation.py -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class ConversationData(BaseData): - class_name: Literal["ConversationData"] = "ConversationData" ----- -labelbox/data/annotation_types/data/html.py -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class HTMLData(BaseData, _NoCoercionMixin): - class_name: Literal["HTMLData"] = "HTMLData" ----- -labelbox/data/annotation_types/data/__init__.py -from .audio import AudioData -from .conversation import ConversationData -from .dicom import DicomData -from .document import DocumentData -from .html import HTMLData -from .raster import ImageData -from .raster import MaskData -from .text import TextData -from .video import VideoData -from .llm_prompt_response_creation import LlmPromptResponseCreationData -from .llm_prompt_creation import LlmPromptCreationData -from .llm_response_creation import LlmResponseCreationData ----- -labelbox/data/annotation_types/data/llm_response_creation.py -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class LlmResponseCreationData(BaseData, _NoCoercionMixin): - class_name: Literal["LlmResponseCreationData"] = "LlmResponseCreationData" ----- -labelbox/data/annotation_types/data/tiled_image.py -from functools import lru_cache -import math -import logging -from enum import Enum -from typing import Optional, List, Tuple, Any, Union, Dict, Callable -from concurrent.futures import ThreadPoolExecutor -from io import BytesIO - -import requests -import numpy as np -from google.api_core import retry -from PIL import Image -from pyproj import Transformer -from pygeotile.point import Point as PygeoPoint -from labelbox import pydantic_compat - -from labelbox.data.annotation_types import Rectangle, Point, Line, Polygon -from .base_data import BaseData -from .raster import RasterData - -VALID_LAT_RANGE = range(-90, 90) -VALID_LNG_RANGE = range(-180, 180) -DEFAULT_TMS_TILE_SIZE = 256 -TILE_DOWNLOAD_CONCURRENCY = 4 - -logger = logging.getLogger(__name__) - -VectorTool = Union[Point, Line, Rectangle, Polygon] - - -class EPSG(Enum): - """ Provides the EPSG for tiled image assets that are currently supported. 
- - SIMPLEPIXEL is Simple that can be used to obtain the pixel space coordinates - - >>> epsg = EPSG() - """ - SIMPLEPIXEL = 1 - EPSG4326 = 4326 - EPSG3857 = 3857 - - -class TiledBounds(pydantic_compat.BaseModel): - """ Bounds for a tiled image asset related to the relevant epsg. - - Bounds should be Point objects. - - >>> bounds = TiledBounds(epsg=EPSG.EPSG4326, - bounds=[ - Point(x=-99.21052827588443, y=19.405662413477728), - Point(x=-99.20534818927473, y=19.400498983095076) - ]) - """ - epsg: EPSG - bounds: List[Point] - - @pydantic_compat.validator('bounds') - def validate_bounds_not_equal(cls, bounds): - first_bound = bounds[0] - second_bound = bounds[1] - - if first_bound.x == second_bound.x or \ - first_bound.y == second_bound.y: - raise ValueError( - f"Bounds on either axes cannot be equal, currently {bounds}") - return bounds - - #validate bounds are within lat,lng range if they are EPSG4326 - @pydantic_compat.root_validator - def validate_bounds_lat_lng(cls, values): - epsg = values.get('epsg') - bounds = values.get('bounds') - - if epsg == EPSG.EPSG4326: - for bound in bounds: - lat, lng = bound.y, bound.x - if int(lng) not in VALID_LNG_RANGE or int( - lat) not in VALID_LAT_RANGE: - raise ValueError(f"Invalid lat/lng bounds. Found {bounds}. " - f"lat must be in {VALID_LAT_RANGE}. " - f"lng must be in {VALID_LNG_RANGE}.") - return values - - -class TileLayer(pydantic_compat.BaseModel): - """ Url that contains the tile layer. Must be in the format: - - https://c.tile.openstreetmap.org/{z}/{x}/{y}.png - - >>> layer = TileLayer( - url="https://c.tile.openstreetmap.org/{z}/{x}/{y}.png", - name="slippy map tile" - ) - """ - url: str - name: Optional[str] = "default" - - def asdict(self) -> Dict[str, str]: - return {"tileLayerUrl": self.url, "name": self.name} - - @pydantic_compat.validator('url') - def validate_url(cls, url): - xyz_format = "/{z}/{x}/{y}" - if xyz_format not in url: - raise ValueError(f"{url} needs to contain {xyz_format}") - return url - - -class TiledImageData(BaseData): - """ Represents tiled imagery - - If specified version is 2, converts bounds from [lng,lat] to [lat,lng] - - Requires the following args: - tile_layer: TileLayer - tile_bounds: TiledBounds - zoom_levels: List[int] - Optional args: - max_native_zoom: int = None - tile_size: Optional[int] - version: int = 2 - alternative_layers: List[TileLayer] - - >>> tiled_image_data = TiledImageData(tile_layer=TileLayer, - tile_bounds=TiledBounds, - zoom_levels=[1, 12]) - """ - tile_layer: TileLayer - tile_bounds: TiledBounds - alternative_layers: List[TileLayer] = [] - zoom_levels: Tuple[int, int] - max_native_zoom: Optional[int] = None - tile_size: Optional[int] = DEFAULT_TMS_TILE_SIZE - version: Optional[int] = 2 - multithread: bool = True - - def __post_init__(self) -> None: - if self.max_native_zoom is None: - self.max_native_zoom = self.zoom_levels[0] - - def asdict(self) -> Dict[str, str]: - return { - "tileLayerUrl": self.tile_layer.url, - "bounds": [[ - self.tile_bounds.bounds[0].x, self.tile_bounds.bounds[0].y - ], [self.tile_bounds.bounds[1].x, self.tile_bounds.bounds[1].y]], - "minZoom": self.zoom_levels[0], - "maxZoom": self.zoom_levels[1], - "maxNativeZoom": self.max_native_zoom, - "epsg": self.tile_bounds.epsg.name, - "tileSize": self.tile_size, - "alternativeLayers": [ - layer.asdict() for layer in self.alternative_layers - ], - "version": self.version - } - - def raster_data(self, - zoom: int = 0, - max_tiles: int = 32, - multithread=True) -> RasterData: - """Converts the tiled image asset into a 
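# A minimal sketch, assuming the geo dependencies (pyproj, pygeotile) are
# installed: assembling a TiledImageData from the classes above. The OSM tile
# URL and bounds mirror the docstring examples.
from labelbox.data.annotation_types import Point
from labelbox.data.annotation_types.data.tiled_image import (
    EPSG, TileLayer, TiledBounds, TiledImageData)

bounds = TiledBounds(epsg=EPSG.EPSG4326,
                     bounds=[Point(x=-99.21052827588443, y=19.405662413477728),
                             Point(x=-99.20534818927473, y=19.400498983095076)])
layer = TileLayer(url="https://c.tile.openstreetmap.org/{z}/{x}/{y}.png",
                  name="slippy map tile")
tiled = TiledImageData(tile_layer=layer, tile_bounds=bounds,
                       zoom_levels=(12, 18))
assert tiled.asdict()["epsg"] == "EPSG4326"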
RasterData object containing an - np.ndarray. - - Uses the minimum zoom provided to render the image. - """ - if self.tile_bounds.epsg == EPSG.SIMPLEPIXEL: - xstart, ystart, xend, yend = self._get_simple_image_params(zoom) - elif self.tile_bounds.epsg == EPSG.EPSG4326: - xstart, ystart, xend, yend = self._get_3857_image_params( - zoom, self.tile_bounds) - elif self.tile_bounds.epsg == EPSG.EPSG3857: - #transform to 4326 - transformer = EPSGTransformer.create_geo_to_geo_transformer( - EPSG.EPSG3857, EPSG.EPSG4326) - transforming_bounds = [ - transformer(self.tile_bounds.bounds[0]), - transformer(self.tile_bounds.bounds[1]) - ] - xstart, ystart, xend, yend = self._get_3857_image_params( - zoom, transforming_bounds) - else: - raise ValueError(f"Unsupported epsg found: {self.tile_bounds.epsg}") - - self._validate_num_tiles(xstart, ystart, xend, yend, max_tiles) - - rounded_tiles, pixel_offsets = list( - zip(*[ - self._tile_to_pixel(pt) for pt in [xstart, ystart, xend, yend] - ])) - - image = self._fetch_image_for_bounds(*rounded_tiles, zoom, multithread) - arr = self._crop_to_bounds(image, *pixel_offsets) - return RasterData(arr=arr) - - @property - def value(self) -> np.ndarray: - """Returns the value of a generated RasterData object. - """ - return self.raster_data(self.zoom_levels[0], - multithread=self.multithread).value - - def _get_simple_image_params(self, - zoom) -> Tuple[float, float, float, float]: - """Computes the x and y tile bounds for fetching an image that - captures the entire labeling region (TiledData.bounds) given a specific zoom - - Simple has different order of x / y than lat / lng because of how leaflet behaves - leaflet reports all points as pixel locations at a zoom of 0 - """ - xend, xstart, yend, ystart = ( - self.tile_bounds.bounds[1].x, - self.tile_bounds.bounds[0].x, - self.tile_bounds.bounds[1].y, - self.tile_bounds.bounds[0].y, - ) - return (*[ - x * (2**(zoom)) / self.tile_size - for x in [xstart, ystart, xend, yend] - ],) - - def _get_3857_image_params( - self, zoom: int, - bounds: TiledBounds) -> Tuple[float, float, float, float]: - """Computes the x and y tile bounds for fetching an image that - captures the entire labeling region (TiledData.bounds) given a specific zoom - """ - lat_start, lat_end = bounds.bounds[1].y, bounds.bounds[0].y - lng_start, lng_end = bounds.bounds[1].x, bounds.bounds[0].x - - # Convert to zoom 0 tile coordinates - xstart, ystart = self._latlng_to_tile(lat_start, lng_start) - xend, yend = self._latlng_to_tile(lat_end, lng_end) - - # Make sure that the tiles are increasing in order - xstart, xend = min(xstart, xend), max(xstart, xend) - ystart, yend = min(ystart, yend), max(ystart, yend) - return (*[pt * 2.0**zoom for pt in [xstart, ystart, xend, yend]],) - - def _latlng_to_tile(self, - lat: float, - lng: float, - zoom=0) -> Tuple[float, float]: - """Converts lat/lng to 3857 tile coordinates - Formula found here: - https://wiki.openstreetmap.org/wiki/Slippy_map_tilenames#lon.2Flat_to_tile_numbers_2 - """ - scale = 2**zoom - lat_rad = math.radians(lat) - x = (lng + 180.0) / 360.0 * scale - y = (1.0 - math.asinh(math.tan(lat_rad)) / math.pi) / 2.0 * scale - return x, y - - def _tile_to_pixel(self, tile: float) -> Tuple[int, int]: - """Rounds a tile coordinate and reports the remainder in pixels - """ - rounded_tile = int(tile) - remainder = tile - rounded_tile - pixel_offset = int(self.tile_size * remainder) - return rounded_tile, pixel_offset - - def _fetch_image_for_bounds(self, - x_tile_start: int, - y_tile_start: int, - x_tile_end: 
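# A self-contained sanity check of the slippy-map conversion implemented by
# _latlng_to_tile above (the standard OSM tile formula):
import math

def latlng_to_tile(lat: float, lng: float, zoom: int = 0):
    scale = 2 ** zoom
    lat_rad = math.radians(lat)
    x = (lng + 180.0) / 360.0 * scale
    y = (1.0 - math.asinh(math.tan(lat_rad)) / math.pi) / 2.0 * scale
    return x, y

# At zoom 0 the whole world is a single tile, so (lat=0, lng=0) lands at its
# center:
assert latlng_to_tile(0.0, 0.0) == (0.5, 0.5)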
int, - y_tile_end: int, - zoom: int, - multithread=True) -> np.ndarray: - """Fetches the tiles and combines them into a single image. - - If a tile cannot be fetched, a padding of expected tile size is instead added. - """ - - if multithread: - tiles = {} - with ThreadPoolExecutor( - max_workers=TILE_DOWNLOAD_CONCURRENCY) as exc: - for x in range(x_tile_start, x_tile_end + 1): - for y in range(y_tile_start, y_tile_end + 1): - tiles[(x, y)] = exc.submit(self._fetch_tile, x, y, zoom) - - rows = [] - for y in range(y_tile_start, y_tile_end + 1): - row = [] - for x in range(x_tile_start, x_tile_end + 1): - try: - if multithread: - row.append(tiles[(x, y)].result()) - else: - row.append(self._fetch_tile(x, y, zoom)) - except: - row.append( - np.zeros(shape=(self.tile_size, self.tile_size, 3), - dtype=np.uint8)) - rows.append(np.hstack(row)) - - return np.vstack(rows) - - @retry.Retry(initial=1, maximum=16, multiplier=2) - def _fetch_tile(self, x: int, y: int, z: int) -> np.ndarray: - """ - Fetches the image and returns an np array. - """ - data = requests.get(self.tile_layer.url.format(x=x, y=y, z=z)) - data.raise_for_status() - decoded = np.array(Image.open(BytesIO(data.content)))[..., :3] - if decoded.shape[:2] != (self.tile_size, self.tile_size): - logger.warning(f"Unexpected tile size {decoded.shape}.") - return decoded - - def _crop_to_bounds( - self, - image: np.ndarray, - x_px_start: int, - y_px_start: int, - x_px_end: int, - y_px_end: int, - ) -> np.ndarray: - """This function slices off the excess pixels that are outside of the bounds. - This occurs because only full tiles can be downloaded at a time. - """ - - def invert_point(pt): - # Must have at least 1 pixel for stability. - pt = max(pt, 1) - # All pixel points are relative to a single tile - # So subtracting the tile size inverts the axis - pt = pt - self.tile_size - return pt if pt != 0 else None - - x_px_end, y_px_end = invert_point(x_px_end), invert_point(y_px_end) - return image[y_px_start:y_px_end, x_px_start:x_px_end, :] - - def _validate_num_tiles(self, xstart: float, ystart: float, xend: float, - yend: float, max_tiles: int): - """Calculates the number of expected tiles we would fetch. - - If this is greater than the number of max tiles, raise an error. - """ - total_n_tiles = (yend - ystart + 1) * (xend - xstart + 1) - if total_n_tiles > max_tiles: - raise ValueError(f"Requested zoom results in {total_n_tiles} tiles." - f"Max allowed tiles are {max_tiles}" - f"Increase max tiles or reduce zoom level.") - - @pydantic_compat.validator('zoom_levels') - def validate_zoom_levels(cls, zoom_levels): - if zoom_levels[0] > zoom_levels[1]: - raise ValueError( - f"Order of zoom levels should be min, max. Received {zoom_levels}" - ) - return zoom_levels - - -class EPSGTransformer(pydantic_compat.BaseModel): - """Transformer class between different EPSG's. Useful when wanting to project - in different formats. - """ - - class Config: - arbitrary_types_allowed = True - - transformer: Any - - @staticmethod - def _is_simple(epsg: EPSG) -> bool: - return epsg == EPSG.SIMPLEPIXEL - - @staticmethod - def _get_ranges(bounds: np.ndarray) -> Tuple[int, int]: - """helper function to get the range between bounds. 
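# The budget enforced by _validate_num_tiles above is just the area of the
# inclusive tile rectangle; a quick worked example with illustrative numbers:
xstart, ystart, xend, yend = 10, 20, 13, 22
total_n_tiles = (yend - ystart + 1) * (xend - xstart + 1)
assert total_n_tiles == 12        # 3 rows x 4 columns of tiles
# With the default max_tiles=32, one extra zoom level (which quadruples the
# count to 48 tiles) would trip the ValueError above.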
- - returns a tuple (x_range, y_range)""" - x_range = np.max(bounds[:, 0]) - np.min(bounds[:, 0]) - y_range = np.max(bounds[:, 1]) - np.min(bounds[:, 1]) - return (x_range, y_range) - - @staticmethod - def _min_max_x_y(bounds: np.ndarray) -> Tuple[int, int, int, int]: - """returns the min x, max x, min y, max y of a numpy array - """ - return np.min(bounds[:, 0]), np.max(bounds[:, 0]), np.min( - bounds[:, 1]), np.max(bounds[:, 1]) - - @classmethod - def geo_and_pixel(cls, - src_epsg, - pixel_bounds: TiledBounds, - geo_bounds: TiledBounds, - zoom=0) -> Callable: - """method to change from one projection to simple projection""" - - pixel_bounds = pixel_bounds.bounds - geo_bounds_epsg = geo_bounds.epsg - geo_bounds = geo_bounds.bounds - - local_bounds = np.array([(point.x, point.y) for point in pixel_bounds], - dtype=int) - #convert geo bounds to pixel bounds. assumes geo bounds are in wgs84/EPS4326 per leaflet - global_bounds = np.array([ - PygeoPoint.from_latitude_longitude(latitude=point.y, - longitude=point.x).pixels(zoom) - for point in geo_bounds - ]) - - #get the range of pixels for both sets of bounds to use as a multiplification factor - local_x_range, local_y_range = cls._get_ranges(bounds=local_bounds) - global_x_range, global_y_range = cls._get_ranges(bounds=global_bounds) - - if src_epsg == EPSG.SIMPLEPIXEL: - - def transform(x: int, y: int) -> Callable[[int, int], Transformer]: - scaled_xy = (x * (global_x_range) / (local_x_range), - y * (global_y_range) / (local_y_range)) - - minx, _, miny, _ = cls._min_max_x_y(bounds=global_bounds) - x, y = map(lambda i, j: i + j, scaled_xy, (minx, miny)) - - point = PygeoPoint.from_pixel(pixel_x=x, pixel_y=y, - zoom=zoom).latitude_longitude - #convert to the desired epsg - return Transformer.from_crs(EPSG.EPSG4326.value, - geo_bounds_epsg.value, - always_xy=True).transform( - point[1], point[0]) - - return transform - - #handles 4326 from lat,lng - elif src_epsg == EPSG.EPSG4326: - - def transform(x: int, y: int) -> Callable[[int, int], Transformer]: - point_in_px = PygeoPoint.from_latitude_longitude( - latitude=y, longitude=x).pixels(zoom) - - minx, _, miny, _ = cls._min_max_x_y(global_bounds) - x, y = map(lambda i, j: i - j, point_in_px, (minx, miny)) - - return (x * (local_x_range) / (global_x_range), - y * (local_y_range) / (global_y_range)) - - return transform - - #handles 3857 from meters - elif src_epsg == EPSG.EPSG3857: - - def transform(x: int, y: int) -> Callable[[int, int], Transformer]: - point_in_px = PygeoPoint.from_meters(meter_y=y, - meter_x=x).pixels(zoom) - - minx, _, miny, _ = cls._min_max_x_y(global_bounds) - x, y = map(lambda i, j: i - j, point_in_px, (minx, miny)) - - return (x * (local_x_range) / (global_x_range), - y * (local_y_range) / (global_y_range)) - - return transform - - @classmethod - def create_geo_to_geo_transformer( - cls, src_epsg: EPSG, - tgt_epsg: EPSG) -> Callable[[int, int], Transformer]: - """method to change from one projection to another projection. - - supports EPSG transformations not Simple. - """ - if cls._is_simple(epsg=src_epsg) or cls._is_simple(epsg=tgt_epsg): - raise Exception( - f"Cannot be used for Simple transformations. 
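# A self-contained sketch (numpy only, illustrative numbers) of the
# range-ratio scaling that geo_and_pixel above applies between annotation
# pixel bounds and zoom-0 global pixel bounds:
import numpy as np

local_bounds = np.array([[0, 0], [512, 256]])            # annotation pixels
global_bounds = np.array([[1024, 2048], [2048, 2560]])   # zoom-0 pixels

local_x_range = local_bounds[:, 0].max() - local_bounds[:, 0].min()
global_x_range = global_bounds[:, 0].max() - global_bounds[:, 0].min()

x_local = 256                                            # halfway across
x_global = x_local * global_x_range / local_x_range + global_bounds[:, 0].min()
assert x_global == 1536.0                                # halfway across too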
Found {src_epsg} and {tgt_epsg}" - ) - - return EPSGTransformer(transformer=Transformer.from_crs( - src_epsg.value, tgt_epsg.value, always_xy=True).transform) - - @classmethod - def create_geo_to_pixel_transformer( - cls, - src_epsg, - pixel_bounds: TiledBounds, - geo_bounds: TiledBounds, - zoom=0) -> Callable[[int, int], Transformer]: - """method to change from a geo projection to Simple""" - - transform_function = cls.geo_and_pixel(src_epsg=src_epsg, - pixel_bounds=pixel_bounds, - geo_bounds=geo_bounds, - zoom=zoom) - return EPSGTransformer(transformer=transform_function) - - @classmethod - def create_pixel_to_geo_transformer( - cls, - src_epsg, - pixel_bounds: TiledBounds, - geo_bounds: TiledBounds, - zoom=0) -> Callable[[int, int], Transformer]: - """method to change from a geo projection to Simple""" - transform_function = cls.geo_and_pixel(src_epsg=src_epsg, - pixel_bounds=pixel_bounds, - geo_bounds=geo_bounds, - zoom=zoom) - return EPSGTransformer(transformer=transform_function) - - def _get_point_obj(self, point) -> Point: - point = self.transformer(point.x, point.y) - return Point(x=point[0], y=point[1]) - - def __call__( - self, shape: Union[Point, Line, Rectangle, Polygon] - ) -> Union[VectorTool, List[VectorTool]]: - if isinstance(shape, list): - return [self(geom) for geom in shape] - if isinstance(shape, Point): - return self._get_point_obj(shape) - if isinstance(shape, Line): - return Line(points=[self._get_point_obj(p) for p in shape.points]) - if isinstance(shape, Polygon): - return Polygon( - points=[self._get_point_obj(p) for p in shape.points]) - if isinstance(shape, Rectangle): - return Rectangle(start=self._get_point_obj(shape.start), - end=self._get_point_obj(shape.end)) - else: - raise ValueError(f"Unsupported type found: {type(shape)}") ----- -labelbox/data/annotation_types/data/base_data.py -from abc import ABC -from typing import Optional, Dict, List, Any - -from labelbox import pydantic_compat - - -class BaseData(pydantic_compat.BaseModel, ABC): - """ - Base class for objects representing data. - This class shouldn't directly be used - """ - external_id: Optional[str] = None - uid: Optional[str] = None - global_key: Optional[str] = None - media_attributes: Optional[Dict[str, Any]] = None - metadata: Optional[List[Dict[str, Any]]] = None - ----- -labelbox/data/annotation_types/data/text.py -from typing import Callable, Optional - -import requests -from requests.exceptions import ConnectTimeout -from google.api_core import retry - -from labelbox import pydantic_compat -from labelbox.exceptions import InternalServerError -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class TextData(BaseData, _NoCoercionMixin): - """ - Represents text data. Requires arg file_path, text, or url - - >>> TextData(text="") - - Args: - file_path (str) - text (str) - url (str) - """ - class_name: Literal["TextData"] = "TextData" - file_path: Optional[str] = None - text: Optional[str] = None - url: Optional[str] = None - - @property - def value(self) -> str: - """ - Property that unifies the data access pattern for all references to the text. 
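# A minimal sketch, assuming pyproj is installed: projecting a point from Web
# Mercator meters to lat/lng degrees with the geo-to-geo factory above.
from labelbox.data.annotation_types import Point
from labelbox.data.annotation_types.data.tiled_image import (
    EPSG, EPSGTransformer)

to_degrees = EPSGTransformer.create_geo_to_geo_transformer(
    EPSG.EPSG3857, EPSG.EPSG4326)
origin = to_degrees(Point(x=0.0, y=0.0))    # meters -> degrees
assert (round(origin.x, 9), round(origin.y, 9)) == (0.0, 0.0)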
- - Returns: - string representation of the text - """ - if self.text: - return self.text - elif self.file_path: - with open(self.file_path, "r") as file: - text = file.read() - self.text = text - return text - elif self.url: - text = self.fetch_remote() - self.text = text - return text - else: - raise ValueError("Must set either url, file_path or im_bytes") - - def set_fetch_fn(self, fn): - object.__setattr__(self, 'fetch_remote', lambda: fn(self)) - - @retry.Retry(deadline=15., - predicate=retry.if_exception_type(ConnectTimeout, - InternalServerError)) - def fetch_remote(self) -> str: - """ - Method for accessing url. - - If url is not publicly accessible or requires another access pattern - simply override this function - """ - response = requests.get(self.url) - if response.status_code in [500, 502, 503, 504]: - raise labelbox.exceptions.InternalServerError(response.text) - response.raise_for_status() - return response.text - - @retry.Retry(deadline=15.) - def create_url(self, signer: Callable[[bytes], str]) -> None: - """ - Utility for creating a url from any of the other text references. - - Args: - signer: A function that accepts bytes and returns a signed url. - Returns: - url for the text - """ - if self.url is not None: - return self.url - elif self.file_path is not None: - with open(self.file_path, 'rb') as file: - self.url = signer(file.read()) - elif self.text is not None: - self.url = signer(self.text.encode()) - else: - raise ValueError( - "One of url, im_bytes, file_path, numpy must not be None.") - return self.url - - @pydantic_compat.root_validator - def validate_date(cls, values): - file_path = values.get("file_path") - text = values.get("text") - url = values.get("url") - uid = values.get('uid') - global_key = values.get('global_key') - if uid == file_path == text == url == global_key == None: - raise ValueError( - "One of `file_path`, `text`, `uid`, `global_key` or `url` required." - ) - return values - - def __repr__(self) -> str: - return f"TextData(file_path={self.file_path}," \ - f"text={self.text[:30] + '...' if self.text is not None else None}," \ - f"url={self.url})" - - class config: - # Required for discriminating between data types - extra = 'forbid' - ----- -labelbox/data/annotation_types/data/document.py -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class DocumentData(BaseData, _NoCoercionMixin): - class_name: Literal["DocumentData"] = "DocumentData" ----- -labelbox/data/annotation_types/data/audio.py -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class AudioData(BaseData, _NoCoercionMixin): - class_name: Literal["AudioData"] = "AudioData" ----- -labelbox/data/annotation_types/data/raster.py -from abc import ABC -from io import BytesIO -from typing import Callable, Optional, Union -from typing_extensions import Literal - -from PIL import Image -from google.api_core import retry -from requests.exceptions import ConnectTimeout -import requests -import numpy as np - -from labelbox import pydantic_compat -from labelbox.exceptions import InternalServerError -from .base_data import BaseData -from ..types import TypedArray - - -class RasterData(pydantic_compat.BaseModel, ABC): - """Represents an image or segmentation mask. 
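# A minimal sketch of the TextData resolution order above: text is preferred,
# then file_path, then url (fetched once and cached on the instance).
from labelbox.data.annotation_types.data import TextData

sample = TextData(text="hello world")
assert sample.value == "hello world"

remote = TextData(url="https://example.com/sample.txt")   # illustrative URL
# remote.value would issue the HTTP GET via fetch_remote(). Note that
# fetch_remote above raises labelbox.exceptions.InternalServerError, while the
# module imports only InternalServerError itself; the bare imported name is
# the one in scope.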
- """ - im_bytes: Optional[bytes] = None - file_path: Optional[str] = None - url: Optional[str] = None - arr: Optional[TypedArray[Literal['uint8']]] = None - - @classmethod - def from_2D_arr(cls, arr: Union[TypedArray[Literal['uint8']], - TypedArray[Literal['int']]], - **kwargs) -> "RasterData": - """Construct from a 2D numpy array - - Args: - arr: uint8 compatible numpy array - - Returns: - RasterData - """ - - if len(arr.shape) != 2: - raise ValueError( - f"Found array with shape {arr.shape}. Expected two dimensions [H, W]" - ) - - if not np.issubdtype(arr.dtype, np.integer): - raise ValueError("Array must be an integer subtype") - - if np.can_cast(arr, np.uint8): - arr = arr.astype(np.uint8) - else: - raise ValueError( - "Could not cast array to uint8, check that values are between 0 and 255" - ) - - arr = np.stack((arr,) * 3, axis=-1) - return cls(arr=arr, **kwargs) - - def bytes_to_np(self, image_bytes: bytes) -> np.ndarray: - """ - Converts image bytes to a numpy array - Args: - image_bytes (bytes): PNG encoded image - Returns: - numpy array representing the image - """ - arr = np.array(Image.open(BytesIO(image_bytes))) - if len(arr.shape) == 2: - arr = np.stack((arr,) * 3, axis=-1) - return arr[:, :, :3] - - def np_to_bytes(self, arr: np.ndarray) -> bytes: - """ - Converts a numpy array to bytes - Args: - arr (np.array): numpy array representing the image - Returns: - png encoded bytes - """ - if len(arr.shape) != 3: - raise ValueError( - "unsupported image format. Must be 3D ([H,W,C])." - f"Use {self.__class__.__name__}.from_2D_arr to construct from 2D" - ) - if arr.dtype != np.uint8: - raise TypeError(f"image data type must be uint8. Found {arr.dtype}") - - im_bytes = BytesIO() - Image.fromarray(arr).save(im_bytes, format="PNG") - return im_bytes.getvalue() - - @property - def value(self) -> np.ndarray: - """ - Property that unifies the data access pattern for all references to the raster. - - Returns: - numpy representation of the raster - """ - if self.arr is not None: - return self.arr - if self.im_bytes is not None: - return self.bytes_to_np(self.im_bytes) - elif self.file_path is not None: - with open(self.file_path, "rb") as img: - im_bytes = img.read() - self.im_bytes = im_bytes - arr = self.bytes_to_np(im_bytes) - return arr - elif self.url is not None: - im_bytes = self.fetch_remote() - self.im_bytes = im_bytes - return self.bytes_to_np(im_bytes) - else: - raise ValueError("Must set either url, file_path or im_bytes") - - def set_fetch_fn(self, fn): - object.__setattr__(self, 'fetch_remote', lambda: fn(self)) - - @retry.Retry(deadline=15., - predicate=retry.if_exception_type(ConnectTimeout, - InternalServerError)) - def fetch_remote(self) -> bytes: - """ - Method for accessing url. - - If url is not publicly accessible or requires another access pattern - simply override this function - """ - response = requests.get(self.url) - if response.status_code in [500, 502, 503, 504]: - raise InternalServerError(response.text) - response.raise_for_status() - return response.content - - @retry.Retry(deadline=30.) - def create_url(self, signer: Callable[[bytes], str]) -> str: - """ - Utility for creating a url from any of the other image representations. - - Args: - signer: A function that accepts bytes and returns a signed url. 
- Returns: - url for the raster data - """ - if self.url is not None: - return self.url - elif self.im_bytes is not None: - self.url = signer(self.im_bytes) - elif self.file_path is not None: - with open(self.file_path, 'rb') as file: - self.url = signer(file.read()) - elif self.arr is not None: - self.url = signer(self.np_to_bytes(self.arr)) - else: - raise ValueError( - "One of url, im_bytes, file_path, arr must not be None.") - return self.url - - @pydantic_compat.root_validator() - def validate_args(cls, values): - file_path = values.get("file_path") - im_bytes = values.get("im_bytes") - url = values.get("url") - arr = values.get("arr") - uid = values.get('uid') - global_key = values.get('global_key') - if uid == file_path == im_bytes == url == global_key == None and arr is None: - raise ValueError( - "One of `file_path`, `im_bytes`, `url`, `uid`, `global_key` or `arr` required." - ) - if arr is not None: - if arr.dtype != np.uint8: - raise TypeError( - "Numpy array representing segmentation mask must be np.uint8" - ) - elif len(arr.shape) != 3: - raise ValueError( - "unsupported image format. Must be 3D ([H,W,C])." - f"Use {cls.__name__}.from_2D_arr to construct from 2D") - return values - - def __repr__(self) -> str: - symbol_or_none = lambda data: '...' if data is not None else None - return f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)}," \ - f"file_path={self.file_path}," \ - f"url={self.url}," \ - f"arr={symbol_or_none(self.arr)})" - - class Config: - # Required for sharing references - copy_on_model_validation = 'none' - # Required for discriminating between data types - extra = 'forbid' - - -class MaskData(RasterData): - """Used to represent a segmentation Mask - - All segments within a mask must be mutually exclusive. At a - single cell, only one class can be present. All Mask data is - converted to a [H,W,3] image. Classes are - - >>> # 3x3 mask with two classes and back ground - >>> MaskData.from_2D_arr([ - >>> [0, 0, 0], - >>> [1, 1, 1], - >>> [2, 2, 2], - >>>]) - - Args: - im_bytes: Optional[bytes] = None - file_path: Optional[str] = None - url: Optional[str] = None - arr: Optional[TypedArray[Literal['uint8']]] = None - """ - - -class ImageData(RasterData, BaseData): - ... - ----- -labelbox/data/annotation_types/data/dicom.py -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class DicomData(BaseData, _NoCoercionMixin): - class_name: Literal["DicomData"] = "DicomData" ----- -labelbox/data/annotation_types/data/video.py -import logging -import os -import urllib.request -from typing import Callable, Dict, Generator, Optional, Tuple -from typing_extensions import Literal -from uuid import uuid4 - -import cv2 -import numpy as np -from google.api_core import retry - -from .base_data import BaseData -from ..types import TypedArray - -from labelbox import pydantic_compat - -logger = logging.getLogger(__name__) - - -class VideoData(BaseData): - """ - Represents video - """ - file_path: Optional[str] = None - url: Optional[str] = None - frames: Optional[Dict[int, TypedArray[Literal['uint8']]]] = None - - def load_frames(self, overwrite: bool = False) -> None: - """ - Loads all frames into memory at once in order to access in non-sequential order. 
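# A minimal sketch of create_url above with a stand-in signer (a real signer
# uploads the bytes and returns a signed cloud URL; this one is hypothetical):
import numpy as np
from labelbox.data.annotation_types.data import ImageData

image = ImageData(arr=np.zeros((8, 8, 3), dtype=np.uint8))

def fake_signer(raw_bytes: bytes) -> str:
    return "https://signed.example/asset.png"   # hypothetical URL

assert image.create_url(fake_signer) == "https://signed.example/asset.png"
assert image.url is not None                    # cached for subsequent calls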
- This will use a lot of memory, especially for longer videos - - Args: - overwrite: Replace existing frames - """ - if self.frames and not overwrite: - return - - for count, frame in self.frame_generator(): - if self.frames is None: - self.frames = {} - self.frames[count] = frame - - @property - def value(self): - return self.frame_generator() - - def frame_generator( - self, - cache_frames=False, - download_dir='/tmp' - ) -> Generator[Tuple[int, np.ndarray], None, None]: - """ - A generator for accessing individual frames in a video. - - Args: - cache_frames (bool): Whether or not to cache frames while iterating through the video. - download_dir (str): Directory to save the video to. Defaults to `/tmp` dir - """ - if self.frames is not None: - for idx, frame in self.frames.items(): - yield idx, frame - return - elif self.url and not self.file_path: - file_path = os.path.join(download_dir, f"{uuid4()}.mp4") - logger.info("Downloading the video locally to %s", file_path) - self.fetch_remote(file_path) - self.file_path = file_path - - vidcap = cv2.VideoCapture(self.file_path) - - success, frame = vidcap.read() - count = 0 - if cache_frames: - self.frames = {} - while success: - frame = frame[:, :, ::-1] - yield count, frame - if cache_frames: - self.frames[count] = frame - success, frame = vidcap.read() - count += 1 - - def __getitem__(self, idx: int) -> np.ndarray: - if self.frames is None: - raise ValueError( - "Cannot select by index without iterating over the entire video or loading all frames." - ) - return self.frames[idx] - - def set_fetch_fn(self, fn): - object.__setattr__(self, 'fetch_remote', lambda: fn(self)) - - @retry.Retry(deadline=15.) - def fetch_remote(self, local_path) -> None: - """ - Method for downloading data from self.url - - If url is not publicly accessible or requires another access pattern - simply override this function - - Args: - local_path: Where to save the thing too. - """ - urllib.request.urlretrieve(self.url, local_path) - - @retry.Retry(deadline=15.) - def create_url(self, signer: Callable[[bytes], str]) -> None: - """ - Utility for creating a url from any of the other video references. - - Args: - signer: A function that accepts bytes and returns a signed url. - Returns: - url for the video - """ - if self.url is not None: - return self.url - elif self.file_path is not None: - with open(self.file_path, 'rb') as file: - self.url = signer(file.read()) - elif self.frames is not None: - self.file_path = self.frames_to_video(self.frames) - self.url = self.create_url(signer) - else: - raise ValueError("One of url, file_path, frames must not be None.") - return self.url - - def frames_to_video(self, - frames: Dict[int, np.ndarray], - fps=20, - save_dir='/tmp') -> str: - """ - Compresses the data by converting a set of individual frames to a single video. - - """ - file_path = os.path.join(save_dir, f"{uuid4()}.mp4") - out = None - for key in frames.keys(): - frame = frames[key] - if out is None: - out = cv2.VideoWriter(file_path, - cv2.VideoWriter_fourcc(*'MP4V'), fps, - frame.shape[:2]) - out.write(frame) - if out is None: - return - out.release() - return file_path - - @pydantic_compat.root_validator - def validate_data(cls, values): - file_path = values.get("file_path") - url = values.get("url") - frames = values.get("frames") - uid = values.get("uid") - global_key = values.get("global_key") - - if uid == file_path == frames == url == global_key == None: - raise ValueError( - "One of `file_path`, `frames`, `uid`, `global_key` or `url` required." 
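# A minimal sketch, assuming opencv-python is installed; the path is
# illustrative: streaming frames with the generator above.
from labelbox.data.annotation_types.data import VideoData

video = VideoData(file_path="/tmp/clip.mp4")
for index, frame in video.frame_generator(cache_frames=True):
    height, width, channels = frame.shape    # frames are yielded as RGB
    if index >= 9:
        break                                # stop after the first ten frames
# Caveat: cv2.VideoWriter expects frameSize as (width, height), while
# frame.shape[:2] is (height, width), so frames_to_video above only writes
# correct output for square frames.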
- ) - return values - - def __repr__(self) -> str: - return f"VideoData(file_path={self.file_path}," \ - f"frames={'...' if self.frames is not None else None}," \ - f"url={self.url})" - - class Config: - # Required for discriminating between data types - extra = 'forbid' - ----- -labelbox/data/annotation_types/data/llm_prompt_creation.py -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class LlmPromptCreationData(BaseData, _NoCoercionMixin): - class_name: Literal["LlmPromptCreationData"] = "LlmPromptCreationData" ----- -labelbox/data/annotation_types/data/llm_prompt_response_creation.py -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class LlmPromptResponseCreationData(BaseData, _NoCoercionMixin): - class_name: Literal[ - "LlmPromptResponseCreationData"] = "LlmPromptResponseCreationData" ---END-- \ No newline at end of file