diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 5ad843b89..a2fb09186 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -407,6 +407,7 @@ def upload_data(self, }), "map": (None, json.dumps({"1": ["variables.file"]})), } + response = requests.post( self.endpoint, headers={"authorization": "Bearer %s" % self.api_key}, diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py index 0bc5b74b1..4f74e9b7f 100644 --- a/libs/labelbox/src/labelbox/schema/dataset.py +++ b/libs/labelbox/src/labelbox/schema/dataset.py @@ -29,27 +29,16 @@ from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params from labelbox.schema.export_task import ExportTask from labelbox.schema.identifiable import UniqueId, GlobalKey -from labelbox.schema.task import Task +from labelbox.schema.task import Task, DataUpsertTask from labelbox.schema.user import User from labelbox.schema.iam_integration import IAMIntegration +from labelbox.schema.internal.data_row_upsert_item import (DataRowUpsertItem) +from labelbox.schema.internal.data_row_uploader import DataRowUploader +from labelbox.schema.internal.datarow_upload_constants import ( + MAX_DATAROW_PER_API_OPERATION, FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE) logger = logging.getLogger(__name__) -MAX_DATAROW_PER_API_OPERATION = 150_000 - - -class DataRowUpsertItem(BaseModel): - id: dict - payload: dict - - def is_empty(self) -> bool: - """ - The payload is considered empty if it's actually empty or the only key is `dataset_id`. - :return: bool - """ - return (not self.payload or - len(self.payload.keys()) == 1 and "dataset_id" in self.payload) - class Dataset(DbObject, Updateable, Deletable): """ A Dataset is a collection of DataRows. @@ -64,7 +53,7 @@ class Dataset(DbObject, Updateable, Deletable): created_by (Relationship): `ToOne` relationship to User organization (Relationship): `ToOne` relationship to Organization """ - __upsert_chunk_size: Final = 10_000 + __upsert_chunk_size: Final = UPSERT_CHUNK_SIZE name = Field.String("name") description = Field.String("description") @@ -251,8 +240,10 @@ def create_data_rows_sync(self, items) -> None: f"Dataset.create_data_rows_sync() supports a max of {max_data_rows_supported} data rows." " For larger imports use the async function Dataset.create_data_rows()" ) - descriptor_url = self._create_descriptor_file( - items, max_attachments_per_data_row=max_attachments_per_data_row) + descriptor_url = DataRowUploader.create_descriptor_file( + self.client, + items, + max_attachments_per_data_row=max_attachments_per_data_row) dataset_param = "datasetId" url_param = "jsonUrl" query_str = """mutation AppendRowsToDatasetSyncPyApi($%s: ID!, $%s: String!){ @@ -264,13 +255,16 @@ def create_data_rows_sync(self, items) -> None: url_param: descriptor_url }) - def create_data_rows(self, items) -> "Task": + def create_data_rows(self, + items, + file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT + ) -> "DataUpsertTask": """ Asynchronously bulk upload data rows Use this instead of `Dataset.create_data_rows_sync` uploads for batches that contain more than 1000 data rows. Args: - items (iterable of (dict or str)): See the docstring for `Dataset._create_descriptor_file` for more information + items (iterable of (dict or str)): See the docstring for `DataRowUploader.create_descriptor_file` for more information Returns: Task representing the data import on the server side. 
The Task @@ -285,271 +279,44 @@ def create_data_rows(self, items) -> "Task": InvalidAttributeError: If there are fields in `items` not valid for a DataRow. ValueError: When the upload parameters are invalid - """ - descriptor_url = self._create_descriptor_file(items) - # Create data source - dataset_param = "datasetId" - url_param = "jsonUrl" - query_str = """mutation AppendRowsToDatasetPyApi($%s: ID!, $%s: String!){ - appendRowsToDataset(data:{datasetId: $%s, jsonFileUrl: $%s} - ){ taskId accepted errorMessage } } """ % (dataset_param, url_param, - dataset_param, url_param) - res = self.client.execute(query_str, { - dataset_param: self.uid, - url_param: descriptor_url - }) - res = res["appendRowsToDataset"] - if not res["accepted"]: - msg = res['errorMessage'] - raise InvalidQueryError( - f"Server did not accept DataRow creation request. {msg}") - - # Fetch and return the task. - task_id = res["taskId"] - user: User = self.client.get_user() - tasks: List[Task] = list( - user.created_tasks(where=Entity.Task.uid == task_id)) - # Cache user in a private variable as the relationship can't be - # resolved due to server-side limitations (see Task.created_by) - # for more info. - if len(tasks) != 1: - raise ResourceNotFoundError(Entity.Task, task_id) - task: Task = tasks[0] - task._user = user - return task - - def _create_descriptor_file(self, - items, - max_attachments_per_data_row=None, - is_upsert=False): + NOTE dicts and strings items can not be mixed in the same call. It is a responsibility of the caller to ensure that all items are of the same type. """ - This function is shared by both `Dataset.create_data_rows` and `Dataset.create_data_rows_sync` - to prepare the input file. The user defined input is validated, processed, and json stringified. - Finally the json data is uploaded to gcs and a uri is returned. This uri can be passed to - + if file_upload_thread_count < 1: + raise ValueError( + "file_upload_thread_count must be a positive integer") - Each element in `items` can be either a `str` or a `dict`. If - it is a `str`, then it is interpreted as a local file path. The file - is uploaded to Labelbox and a DataRow referencing it is created. - - If an item is a `dict`, then it could support one of the two following structures - 1. For static imagery, video, and text it should map `DataRow` field names to values. - At the minimum an `item` passed as a `dict` must contain a `row_data` key and value. - If the value for row_data is a local file path and the path exists, - then the local file will be uploaded to labelbox. - - 2. For tiled imagery the dict must match the import structure specified in the link below - https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import - - >>> dataset.create_data_rows([ - >>> {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"}, - >>> {DataRow.row_data:"/path/to/file1.jpg"}, - >>> "path/to/file2.jpg", - >>> {DataRow.row_data: {"tileLayerUrl" : "http://", ...}} - >>> {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}} - >>> ]) - - For an example showing how to upload tiled data_rows see the following notebook: - https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb - - Args: - items (iterable of (dict or str)): See above for details. - max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine - if the user has provided too many attachments. - - Returns: - uri (string): A reference to the uploaded json data. 
- - Raises: - InvalidQueryError: If the `items` parameter does not conform to - the specification above or if the server did not accept the - DataRow creation request (unknown reason). - InvalidAttributeError: If there are fields in `items` not valid for - a DataRow. - ValueError: When the upload parameters are invalid - """ - file_upload_thread_count = 20 - DataRow = Entity.DataRow - AssetAttachment = Entity.AssetAttachment - - def upload_if_necessary(item): - if is_upsert and 'row_data' not in item: - # When upserting, row_data is not required - return item - row_data = item['row_data'] - if isinstance(row_data, str) and os.path.exists(row_data): - item_url = self.client.upload_file(row_data) - item['row_data'] = item_url - if 'external_id' not in item: - # Default `external_id` to local file name - item['external_id'] = row_data - return item - - def validate_attachments(item): - attachments = item.get('attachments') - if attachments: - if isinstance(attachments, list): - if max_attachments_per_data_row and len( - attachments) > max_attachments_per_data_row: - raise ValueError( - f"Max attachments number of supported attachments per data row is {max_attachments_per_data_row}." - f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary." - ) - for attachment in attachments: - AssetAttachment.validate_attachment_json(attachment) - else: - raise ValueError( - f"Attachments must be a list. Found {type(attachments)}" - ) - return attachments - - def validate_embeddings(item): - embeddings = item.get("embeddings") - if embeddings: - item["embeddings"] = [ - EmbeddingVector(**e).to_gql() for e in embeddings - ] - - def validate_conversational_data(conversational_data: list) -> None: - """ - Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json - - Args: - conversational_data (list): list of dictionaries. - """ + string_items = [item for item in items if isinstance(item, str)] + dict_items = [item for item in items if isinstance(item, dict)] + dict_string_items = [] - def check_message_keys(message): - accepted_message_keys = set([ - "messageId", "timestampUsec", "content", "user", "align", - "canLabel" - ]) - for key in message.keys(): - if not key in accepted_message_keys: - raise KeyError( - f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}" - ) - - if conversational_data and not isinstance(conversational_data, - list): - raise ValueError( - f"conversationalData must be a list. 
Found {type(conversational_data)}" - ) + if len(string_items) > 0: + dict_string_items = self._build_from_local_paths(string_items) + specs = DataRowUpsertItem.build(self.uid, + dict_items + dict_string_items) + return self._exec_upsert_data_rows(specs, file_upload_thread_count) - [check_message_keys(message) for message in conversational_data] - - def parse_metadata_fields(item): - metadata_fields = item.get('metadata_fields') - if metadata_fields: - mdo = self.client.get_data_row_metadata_ontology() - item['metadata_fields'] = mdo.parse_upsert_metadata( - metadata_fields) - - def format_row(item): - # Formats user input into a consistent dict structure - if isinstance(item, dict): - # Convert fields to strings - item = { - key.name if isinstance(key, Field) else key: value - for key, value in item.items() - } - elif isinstance(item, str): - # The main advantage of using a string over a dict is that the user is specifying - # that the file should exist locally. - # That info is lost after this section so we should check for it here. - if not os.path.exists(item): - raise ValueError(f"Filepath {item} does not exist.") - item = {"row_data": item, "external_id": item} - return item - - def validate_keys(item): - if not is_upsert and 'row_data' not in item: - raise InvalidQueryError( - "`row_data` missing when creating DataRow.") - - if isinstance(item.get('row_data'), - str) and item.get('row_data').startswith("s3:/"): - raise InvalidQueryError( - "row_data: s3 assets must start with 'https'.") - allowed_extra_fields = { - 'attachments', 'media_type', 'dataset_id', 'embeddings' - } - invalid_keys = set(item) - {f.name for f in DataRow.fields() - } - allowed_extra_fields - if invalid_keys: - raise InvalidAttributeError(DataRow, invalid_keys) - return item - - def formatLegacyConversationalData(item): - messages = item.pop("conversationalData") - version = item.pop("version", 1) - type = item.pop("type", "application/vnd.labelbox.conversational") - if "externalId" in item: - external_id = item.pop("externalId") - item["external_id"] = external_id - if "globalKey" in item: - global_key = item.pop("globalKey") - item["globalKey"] = global_key - validate_conversational_data(messages) - one_conversation = \ - { - "type": type, - "version": version, - "messages": messages - } - item["row_data"] = one_conversation - return item - - def convert_item(data_row_item): - if isinstance(data_row_item, DataRowUpsertItem): - item = data_row_item.payload - else: - item = data_row_item - - if "tileLayerUrl" in item: - validate_attachments(item) - return item - - if "conversationalData" in item: - formatLegacyConversationalData(item) - - # Convert all payload variations into the same dict format - item = format_row(item) - # Make sure required keys exist (and there are no extra keys) - validate_keys(item) - # Make sure attachments are valid - validate_attachments(item) - # Make sure embeddings are valid - validate_embeddings(item) - # Parse metadata fields if they exist - parse_metadata_fields(item) - # Upload any local file paths - item = upload_if_necessary(item) - - if isinstance(data_row_item, DataRowUpsertItem): - return {'id': data_row_item.id, 'payload': item} - else: - return item - - if not isinstance(items, Iterable): - raise ValueError( - f"Must pass an iterable to create_data_rows. 
Found {type(items)}" - ) + def _build_from_local_paths( + self, + items: List[str], + file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT) -> List[dict]: + uploaded_items = [] - if len(items) > MAX_DATAROW_PER_API_OPERATION: - raise MalformedQueryException( - f"Cannot create more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call." - ) + def upload_file(item): + item_url = self.client.upload_file(item) + return {'row_data': item_url, 'external_id': item} with ThreadPoolExecutor(file_upload_thread_count) as executor: - futures = [executor.submit(convert_item, item) for item in items] - items = [future.result() for future in as_completed(futures)] - # Prepare and upload the desciptor file - data = json.dumps(items) - return self.client.upload_data(data, - content_type="application/json", - filename="json_import.json") + futures = [ + executor.submit(upload_file, item) + for item in items + if isinstance(item, str) and os.path.exists(item) + ] + more_items = [future.result() for future in as_completed(futures)] + uploaded_items.extend(more_items) + + return uploaded_items def data_rows_for_external_id(self, external_id, @@ -809,7 +576,10 @@ def _export( is_streamable = res["isStreamable"] return Task.get_task(self.client, task_id), is_streamable - def upsert_data_rows(self, items, file_upload_thread_count=20) -> "Task": + def upsert_data_rows(self, + items, + file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT + ) -> "DataUpsertTask": """ Upserts data rows in this dataset. When "key" is provided, and it references an existing data row, an update will be performed. When "key" is not provided a new data row will be created. @@ -840,40 +610,22 @@ def upsert_data_rows(self, items, file_upload_thread_count=20) -> "Task": >>> ]) >>> task.wait_till_done() """ - if len(items) > MAX_DATAROW_PER_API_OPERATION: - raise MalformedQueryException( - f"Cannot upsert more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call." 
- ) - - specs = self._convert_items_to_upsert_format(items) - - empty_specs = list(filter(lambda spec: spec.is_empty(), specs)) + specs = DataRowUpsertItem.build(self.uid, items, (UniqueId, GlobalKey)) + return self._exec_upsert_data_rows(specs, file_upload_thread_count) - if empty_specs: - ids = list(map(lambda spec: spec.id.get("value"), empty_specs)) - raise ValueError( - f"The following items have an empty payload: {ids}") - - chunks = [ - specs[i:i + self.__upsert_chunk_size] - for i in range(0, len(specs), self.__upsert_chunk_size) - ] - - def _upload_chunk(_chunk): - return self._create_descriptor_file(_chunk, is_upsert=True) + def _exec_upsert_data_rows( + self, + specs: List[DataRowUpsertItem], + file_upload_thread_count: int = FILE_UPLOAD_THREAD_COUNT + ) -> "DataUpsertTask": - with ThreadPoolExecutor(file_upload_thread_count) as executor: - futures = [ - executor.submit(_upload_chunk, chunk) for chunk in chunks - ] - chunk_uris = [future.result() for future in as_completed(futures)] + manifest = DataRowUploader.upload_in_chunks( + client=self.client, + specs=specs, + file_upload_thread_count=file_upload_thread_count, + upsert_chunk_size=UPSERT_CHUNK_SIZE) - manifest = { - "source": "SDK", - "item_count": len(specs), - "chunk_uris": chunk_uris - } - data = json.dumps(manifest).encode("utf-8") + data = json.dumps(manifest.dict()).encode("utf-8") manifest_uri = self.client.upload_data(data, content_type="application/json", filename="manifest.json") @@ -888,37 +640,17 @@ def _upload_chunk(_chunk): res = self.client.execute(query_str, {"manifestUri": manifest_uri}) res = res["upsertDataRows"] - task = Task(self.client, res) + task = DataUpsertTask(self.client, res) task._user = self.client.get_user() return task - def _convert_items_to_upsert_format(self, _items): - _upsert_items: List[DataRowUpsertItem] = [] - for item in _items: - # enforce current dataset's id for all specs - item['dataset_id'] = self.uid - key = item.pop('key', None) - if not key: - key = {'type': 'AUTO', 'value': ''} - elif isinstance(key, UniqueId): - key = {'type': 'ID', 'value': key.key} - elif isinstance(key, GlobalKey): - key = {'type': 'GKEY', 'value': key.key} - else: - raise ValueError( - f"Key must be an instance of UniqueId or GlobalKey, got: {type(item['key']).__name__}" - ) - item = { - k: v for k, v in item.items() if v is not None - } # remove None values - _upsert_items.append(DataRowUpsertItem(payload=item, id=key)) - return _upsert_items - - def add_iam_integration(self, iam_integration: Union[str, IAMIntegration]) -> IAMIntegration: + def add_iam_integration( + self, iam_integration: Union[str, + IAMIntegration]) -> IAMIntegration: """ Sets the IAM integration for the dataset. IAM integration is used to sign URLs for data row assets. - Args: + Args: iam_integration (Union[str, IAMIntegration]): IAM integration object or IAM integration id. Returns: @@ -949,7 +681,8 @@ def add_iam_integration(self, iam_integration: Union[str, IAMIntegration]) -> IA >>> dataset.set_iam_integration(iam_integration) """ - iam_integration_id = iam_integration.uid if isinstance(iam_integration, IAMIntegration) else iam_integration + iam_integration_id = iam_integration.uid if isinstance( + iam_integration, IAMIntegration) else iam_integration query = """ mutation SetSignerForDatasetPyApi($signerId: ID!, $datasetId: ID!) 
{ @@ -965,20 +698,30 @@ def add_iam_integration(self, iam_integration: Union[str, IAMIntegration]) -> IA } """ - response = self.client.execute(query, {"signerId": iam_integration_id, "datasetId": self.uid}) + response = self.client.execute(query, { + "signerId": iam_integration_id, + "datasetId": self.uid + }) if not response: - raise ResourceNotFoundError(IAMIntegration, {"signerId": iam_integration_id, "datasetId": self.uid}) - - try: - iam_integration_id = response.get("setSignerForDataset", {}).get("signer", {})["id"] + raise ResourceNotFoundError(IAMIntegration, { + "signerId": iam_integration_id, + "datasetId": self.uid + }) - return [integration for integration - in self.client.get_organization().get_iam_integrations() - if integration.uid == iam_integration_id][0] + try: + iam_integration_id = response.get("setSignerForDataset", + {}).get("signer", {})["id"] + + return [ + integration for integration in + self.client.get_organization().get_iam_integrations() + if integration.uid == iam_integration_id + ][0] except: - raise LabelboxError(f"Can't retrieve IAM integration {iam_integration_id}") - + raise LabelboxError( + f"Can't retrieve IAM integration {iam_integration_id}") + def remove_iam_integration(self) -> None: """ Unsets the IAM integration for the dataset. @@ -1008,4 +751,3 @@ def remove_iam_integration(self) -> None: if not response: raise ResourceNotFoundError(Dataset, {"id": self.uid}) - \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py b/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py new file mode 100644 index 000000000..9be4e2ffd --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py @@ -0,0 +1,287 @@ +import json +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +from typing import Iterable, List + +from labelbox.exceptions import InvalidQueryError +from labelbox.exceptions import InvalidAttributeError +from labelbox.exceptions import MalformedQueryException +from labelbox.orm.model import Entity +from labelbox.orm.model import Field +from labelbox.schema.embedding import EmbeddingVector +from labelbox.pydantic_compat import BaseModel +from labelbox.schema.internal.datarow_upload_constants import ( + MAX_DATAROW_PER_API_OPERATION, FILE_UPLOAD_THREAD_COUNT) +from labelbox.schema.internal.data_row_upsert_item import DataRowUpsertItem + + +class UploadManifest(BaseModel): + source: str + item_count: int + chunk_uris: List[str] + + +class DataRowUploader: + + @staticmethod + def create_descriptor_file(client, + items, + max_attachments_per_data_row=None, + is_upsert=False): + """ + This function is shared by `Dataset.create_data_rows`, `Dataset.create_data_rows_sync` and `Dataset.update_data_rows`. + It is used to prepare the input file. The user defined input is validated, processed, and json stringified. + Finally the json data is uploaded to gcs and a uri is returned. This uri can be passed as a parameter to a mutation that uploads data rows + + Each element in `items` can be either a `str` or a `dict`. If + it is a `str`, then it is interpreted as a local file path. The file + is uploaded to Labelbox and a DataRow referencing it is created. + + If an item is a `dict`, then it could support one of the two following structures + 1. For static imagery, video, and text it should map `DataRow` field names to values. + At the minimum an `items` passed as a `dict` must contain a `row_data` key and value. 
+ If the value for row_data is a local file path and the path exists, + then the local file will be uploaded to labelbox. + + 2. For tiled imagery the dict must match the import structure specified in the link below + https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import + + >>> dataset.create_data_rows([ + >>> {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"}, + >>> {DataRow.row_data:"/path/to/file1.jpg"}, + >>> "path/to/file2.jpg", + >>> {DataRow.row_data: {"tileLayerUrl" : "http://", ...}} + >>> {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}} + >>> ]) + + For an example showing how to upload tiled data_rows see the following notebook: + https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb + + Args: + items (iterable of (dict or str)): See above for details. + max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine + if the user has provided too many attachments. + + Returns: + uri (string): A reference to the uploaded json data. + + Raises: + InvalidQueryError: If the `items` parameter does not conform to + the specification above or if the server did not accept the + DataRow creation request (unknown reason). + InvalidAttributeError: If there are fields in `items` not valid for + a DataRow. + ValueError: When the upload parameters are invalid + """ + file_upload_thread_count = FILE_UPLOAD_THREAD_COUNT + DataRow = Entity.DataRow + AssetAttachment = Entity.AssetAttachment + + def upload_if_necessary(item): + if is_upsert and 'row_data' not in item: + # When upserting, row_data is not required + return item + row_data = item['row_data'] + if isinstance(row_data, str) and os.path.exists(row_data): + item_url = client.upload_file(row_data) + item['row_data'] = item_url + if 'external_id' not in item: + # Default `external_id` to local file name + item['external_id'] = row_data + return item + + def validate_attachments(item): + attachments = item.get('attachments') + if attachments: + if isinstance(attachments, list): + if max_attachments_per_data_row and len( + attachments) > max_attachments_per_data_row: + raise ValueError( + f"Max attachments number of supported attachments per data row is {max_attachments_per_data_row}." + f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary." + ) + for attachment in attachments: + AssetAttachment.validate_attachment_json(attachment) + else: + raise ValueError( + f"Attachments must be a list. Found {type(attachments)}" + ) + return attachments + + def validate_embeddings(item): + embeddings = item.get("embeddings") + if embeddings: + item["embeddings"] = [ + EmbeddingVector(**e).to_gql() for e in embeddings + ] + + def validate_conversational_data(conversational_data: list) -> None: + """ + Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json + + Args: + conversational_data (list): list of dictionaries. + """ + + def check_message_keys(message): + accepted_message_keys = set([ + "messageId", "timestampUsec", "content", "user", "align", + "canLabel" + ]) + for key in message.keys(): + if not key in accepted_message_keys: + raise KeyError( + f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}" + ) + + if conversational_data and not isinstance(conversational_data, + list): + raise ValueError( + f"conversationalData must be a list. 
Found {type(conversational_data)}" + ) + + [check_message_keys(message) for message in conversational_data] + + def parse_metadata_fields(item): + metadata_fields = item.get('metadata_fields') + if metadata_fields: + mdo = client.get_data_row_metadata_ontology() + item['metadata_fields'] = mdo.parse_upsert_metadata( + metadata_fields) + + def format_row(item): + # Formats user input into a consistent dict structure + if isinstance(item, dict): + # Convert fields to strings + item = { + key.name if isinstance(key, Field) else key: value + for key, value in item.items() + } + elif isinstance(item, str): + # The main advantage of using a string over a dict is that the user is specifying + # that the file should exist locally. + # That info is lost after this section so we should check for it here. + if not os.path.exists(item): + raise ValueError(f"Filepath {item} does not exist.") + item = {"row_data": item, "external_id": item} + return item + + def validate_keys(item): + if not is_upsert and 'row_data' not in item: + raise InvalidQueryError( + "`row_data` missing when creating DataRow.") + + if isinstance(item.get('row_data'), + str) and item.get('row_data').startswith("s3:/"): + raise InvalidQueryError( + "row_data: s3 assets must start with 'https'.") + allowed_extra_fields = { + 'attachments', 'media_type', 'dataset_id', 'embeddings' + } + invalid_keys = set(item) - {f.name for f in DataRow.fields() + } - allowed_extra_fields + if invalid_keys: + raise InvalidAttributeError(DataRow, invalid_keys) + return item + + def format_legacy_conversational_data(item): + messages = item.pop("conversationalData") + version = item.pop("version", 1) + type = item.pop("type", "application/vnd.labelbox.conversational") + if "externalId" in item: + external_id = item.pop("externalId") + item["external_id"] = external_id + if "globalKey" in item: + global_key = item.pop("globalKey") + item["globalKey"] = global_key + validate_conversational_data(messages) + one_conversation = \ + { + "type": type, + "version": version, + "messages": messages + } + item["row_data"] = one_conversation + return item + + def convert_item(data_row_item): + if isinstance(data_row_item, DataRowUpsertItem): + item = data_row_item.payload + else: + item = data_row_item + + if "tileLayerUrl" in item: + validate_attachments(item) + return item + + if "conversationalData" in item: + format_legacy_conversational_data(item) + + # Convert all payload variations into the same dict format + item = format_row(item) + # Make sure required keys exist (and there are no extra keys) + validate_keys(item) + # Make sure attachments are valid + validate_attachments(item) + # Make sure embeddings are valid + validate_embeddings(item) + # Parse metadata fields if they exist + parse_metadata_fields(item) + # Upload any local file paths + item = upload_if_necessary(item) + + if isinstance(data_row_item, DataRowUpsertItem): + return {'id': data_row_item.id, 'payload': item} + else: + return item + + if not isinstance(items, Iterable): + raise ValueError( + f"Must pass an iterable to create_data_rows. Found {type(items)}" + ) + + if len(items) > MAX_DATAROW_PER_API_OPERATION: + raise MalformedQueryException( + f"Cannot create more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call." 
+ ) + + with ThreadPoolExecutor(file_upload_thread_count) as executor: + futures = [executor.submit(convert_item, item) for item in items] + items = [future.result() for future in as_completed(futures)] + # Prepare and upload the desciptor file + data = json.dumps(items) + return client.upload_data(data, + content_type="application/json", + filename="json_import.json") + + @staticmethod + def upload_in_chunks(client, specs: List[DataRowUpsertItem], + file_upload_thread_count: int, + upsert_chunk_size: int) -> UploadManifest: + empty_specs = list(filter(lambda spec: spec.is_empty(), specs)) + + if empty_specs: + ids = list(map(lambda spec: spec.id.get("value"), empty_specs)) + raise ValueError( + f"The following items have an empty payload: {ids}") + + chunks = [ + specs[i:i + upsert_chunk_size] + for i in range(0, len(specs), upsert_chunk_size) + ] + + def _upload_chunk(chunk): + return DataRowUploader.create_descriptor_file(client, + chunk, + is_upsert=True) + + with ThreadPoolExecutor(file_upload_thread_count) as executor: + futures = [ + executor.submit(_upload_chunk, chunk) for chunk in chunks + ] + chunk_uris = [future.result() for future in as_completed(futures)] + + return UploadManifest(source="SDK", + item_count=len(specs), + chunk_uris=chunk_uris) diff --git a/libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py b/libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py new file mode 100644 index 000000000..e2c0cb2b5 --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py @@ -0,0 +1,50 @@ +from typing import List, Tuple, Optional + +from labelbox.schema.identifiable import UniqueId, GlobalKey +from labelbox.pydantic_compat import BaseModel + + +class DataRowUpsertItem(BaseModel): + """ + Base class for creating payloads for upsert operations. + """ + id: dict + payload: dict + + @classmethod + def build( + cls, + dataset_id: str, + items: List[dict], + key_types: Optional[Tuple[type, ...]] = () + ) -> List["DataRowUpsertItem"]: + upload_items = [] + + for item in items: + # enforce current dataset's id for all specs + item['dataset_id'] = dataset_id + key = item.pop('key', None) + if not key: + key = {'type': 'AUTO', 'value': ''} + elif isinstance(key, key_types): # type: ignore + key = {'type': key.id_type.value, 'value': key.key} + else: + if not key_types: + raise ValueError( + f"Can not have a key for this item, got: {key}") + raise ValueError( + f"Key must be an instance of {', '.join([t.__name__ for t in key_types])}, got: {type(item['key']).__name__}" + ) + item = { + k: v for k, v in item.items() if v is not None + } # remove None values + upload_items.append(cls(payload=item, id=key)) + return upload_items + + def is_empty(self) -> bool: + """ + The payload is considered empty if it's actually empty or the only key is `dataset_id`. 
+ :return: bool + """ + return (not self.payload or + len(self.payload.keys()) == 1 and "dataset_id" in self.payload) diff --git a/libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py b/libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py new file mode 100644 index 000000000..f4c919095 --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py @@ -0,0 +1,4 @@ +MAX_DATAROW_PER_API_OPERATION = 150_000 +FILE_UPLOAD_THREAD_COUNT = 20 +UPSERT_CHUNK_SIZE = 10_000 +DOWNLOAD_RESULT_PAGE_SIZE = 5_000 diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py index 977f7d4fa..83dea3d1f 100644 --- a/libs/labelbox/src/labelbox/schema/task.py +++ b/libs/labelbox/src/labelbox/schema/task.py @@ -2,13 +2,19 @@ import logging import requests import time -from typing import TYPE_CHECKING, Callable, Optional, Dict, Any, List, Union +from typing import TYPE_CHECKING, Callable, Optional, Dict, Any, List, Union, Final from labelbox import parser from labelbox.exceptions import ResourceNotFoundError from labelbox.orm.db_object import DbObject from labelbox.orm.model import Field, Relationship, Entity +from labelbox.pagination import PaginatedCollection +from labelbox.schema.internal.datarow_upload_constants import ( + MAX_DATAROW_PER_API_OPERATION, + DOWNLOAD_RESULT_PAGE_SIZE, +) + if TYPE_CHECKING: from labelbox import User @@ -49,6 +55,13 @@ class Task(DbObject): created_by = Relationship.ToOne("User", False, "created_by") organization = Relationship.ToOne("Organization") + def __eq__(self, task): + return isinstance( + task, Task) and task.uid == self.uid and task.type == self.type + + def __hash__(self): + return hash(self.uid) + # Import and upsert have several instances of special casing def is_creation_task(self) -> bool: return self.name == 'JSON Import' or self.type == 'adv-upsert-data-rows' @@ -214,3 +227,169 @@ def get_task(client, task_id): task: Task = tasks[0] task._user = user return task + + +class DataUpsertTask(Task): + """ + Task class for data row upsert operations + """ + MAX_DOWNLOAD_SIZE: Final = MAX_DATAROW_PER_API_OPERATION + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._user = None + + @property + def result(self) -> Optional[List[Dict[str, Any]]]: # type: ignore + """ + Fetches all results. + Note, for large uploads (>150K data rows), it could take multiple minutes to complete + """ + if self.status == "FAILED": + raise ValueError(f"Job failed. Errors : {self.errors}") + return self._results_as_list() + + @property + def errors(self) -> Optional[List[Dict[str, Any]]]: # type: ignore + """ + Fetches all errors. 
+ Note, for large uploads / large number of errors (>150K), it could take multiple minutes to complete + """ + return self._errors_as_list() + + @property + def created_data_rows( # type: ignore + self) -> Optional[List[Dict[str, Any]]]: + return self.result + + @property + def failed_data_rows( # type: ignore + self) -> Optional[List[Dict[str, Any]]]: + return self.errors + + def _download_results_paginated(self) -> PaginatedCollection: + page_size = DOWNLOAD_RESULT_PAGE_SIZE + from_cursor = None + + query_str = """query SuccessesfulDataRowImportsPyApi($taskId: ID!, $first: Int, $from: String) { + successesfulDataRowImports(data: { taskId: $taskId, first: $first, from: $from}) + { + nodes { + id + externalId + globalKey + rowData + } + after + total + } + } + """ + + params = { + 'taskId': self.uid, + 'first': page_size, + 'from': from_cursor, + } + + return PaginatedCollection( + client=self.client, + query=query_str, + params=params, + dereferencing=['successesfulDataRowImports', 'nodes'], + obj_class=lambda _, data_row: { + 'id': data_row.get('id'), + 'external_id': data_row.get('externalId'), + 'row_data': data_row.get('rowData'), + 'global_key': data_row.get('globalKey'), + }, + cursor_path=['successesfulDataRowImports', 'after'], + ) + + def _download_errors_paginated(self) -> PaginatedCollection: + page_size = DOWNLOAD_RESULT_PAGE_SIZE # hardcode to avoid overloading the server + from_cursor = None + + query_str = """query FailedDataRowImportsPyApi($taskId: ID!, $first: Int, $from: String) { + failedDataRowImports(data: { taskId: $taskId, first: $first, from: $from}) + { + after + total + results { + message + spec { + externalId + globalKey + rowData + metadata { + schemaId + value + name + } + attachments { + type + value + name + } + } + } + } + } + """ + + params = { + 'taskId': self.uid, + 'first': page_size, + 'from': from_cursor, + } + + def convert_errors_to_legacy_format(client, data_row): + spec = data_row.get('spec', {}) + return { + 'message': + data_row.get('message'), + 'failedDataRows': [{ + 'externalId': spec.get('externalId'), + 'rowData': spec.get('rowData'), + 'globalKey': spec.get('globalKey'), + 'metadata': spec.get('metadata', []), + 'attachments': spec.get('attachments', []), + }] + } + + return PaginatedCollection( + client=self.client, + query=query_str, + params=params, + dereferencing=['failedDataRowImports', 'results'], + obj_class=convert_errors_to_legacy_format, + cursor_path=['failedDataRowImports', 'after'], + ) + + def _results_as_list(self) -> Optional[List[Dict[str, Any]]]: + total_downloaded = 0 + results = [] + data = self._download_results_paginated() + + for row in data: + results.append(row) + total_downloaded += 1 + + if len(results) == 0: + return None + + return results + + def _errors_as_list(self) -> Optional[List[Dict[str, Any]]]: + total_downloaded = 0 + errors = [] + data = self._download_errors_paginated() + + for row in data: + errors.append(row) + total_downloaded += 1 + + if len(errors) == 0: + return None + + return errors diff --git a/libs/labelbox/tests/data/annotation_import/test_model_run.py b/libs/labelbox/tests/data/annotation_import/test_model_run.py index 328b38ba5..3b2e04e62 100644 --- a/libs/labelbox/tests/data/annotation_import/test_model_run.py +++ b/libs/labelbox/tests/data/annotation_import/test_model_run.py @@ -174,6 +174,7 @@ def test_model_run_split_assignment_by_data_row_ids( data_rows = dataset.create_data_rows([{ "row_data": image_url } for _ in range(n_data_rows)]) + data_rows.wait_till_done() data_row_ids = 
[data_row['id'] for data_row in data_rows.result] configured_project_with_one_data_row._wait_until_data_rows_are_processed( data_row_ids=data_row_ids) diff --git a/libs/labelbox/tests/integration/test_data_rows.py b/libs/labelbox/tests/integration/test_data_rows.py index 672afe85d..30a462f80 100644 --- a/libs/labelbox/tests/integration/test_data_rows.py +++ b/libs/labelbox/tests/integration/test_data_rows.py @@ -2,12 +2,13 @@ import uuid from datetime import datetime import json +import requests +import os -from labelbox.schema.media_type import MediaType - +from unittest.mock import patch import pytest -import requests +from labelbox.schema.media_type import MediaType from labelbox import DataRow, AssetAttachment from labelbox.exceptions import MalformedQueryException from labelbox.schema.task import Task @@ -171,56 +172,110 @@ def test_lookup_data_rows(client, dataset): def test_data_row_bulk_creation(dataset, rand_gen, image_url): client = dataset.client + data_rows = [] assert len(list(dataset.data_rows())) == 0 - # Test creation using URL - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url - }, - { - "row_data": image_url - }, - ]) - assert task in client.get_user().created_tasks() - task.wait_till_done() - assert task.status == "COMPLETE" + try: + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', + new=1): # Force chunking + # Test creation using URL + task = dataset.create_data_rows([ + { + DataRow.row_data: image_url + }, + { + "row_data": image_url + }, + ], + file_upload_thread_count=2) + task.wait_till_done() + assert task.has_errors() is False + assert task.status == "COMPLETE" - data_rows = list(dataset.data_rows()) - assert len(data_rows) == 2 - assert {data_row.row_data for data_row in data_rows} == {image_url} - assert {data_row.global_key for data_row in data_rows} == {None} + results = task.result + assert len(results) == 2 + row_data = [result["row_data"] for result in results] + assert row_data == [image_url, image_url] - data_rows = list(dataset.data_rows(from_cursor=data_rows[0].uid)) - assert len(data_rows) == 1 + data_rows = list(dataset.data_rows()) + assert len(data_rows) == 2 + assert {data_row.row_data for data_row in data_rows} == {image_url} + assert {data_row.global_key for data_row in data_rows} == {None} - # Test creation using file name - with NamedTemporaryFile() as fp: - data = rand_gen(str).encode() - fp.write(data) - fp.flush() - task = dataset.create_data_rows([fp.name]) + data_rows = list(dataset.data_rows(from_cursor=data_rows[0].uid)) + assert len(data_rows) == 1 + + finally: + for dr in data_rows: + dr.delete() + + +@pytest.fixture +def local_image_file(image_url) -> NamedTemporaryFile: + response = requests.get(image_url, stream=True) + response.raise_for_status() + + with NamedTemporaryFile(delete=False) as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + yield f # Return the path to the temp file + + os.remove(f.name) + + +def test_data_row_bulk_creation_from_file(dataset, local_image_file, image_url): + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', + new=1): # Force chunking + task = dataset.create_data_rows( + [local_image_file.name, local_image_file.name]) task.wait_till_done() assert task.status == "COMPLETE" + assert len(task.result) == 2 + assert task.has_errors() is False + results = task.result + row_data = [result["row_data"] for result in results] + assert len(row_data) == 2 + +def test_data_row_bulk_creation_from_row_data_file_external_id( + dataset, 
local_image_file, image_url): + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', + new=1): # Force chunking task = dataset.create_data_rows([{ - "row_data": fp.name, + "row_data": local_image_file.name, 'external_id': 'some_name' + }, { + "row_data": image_url, + 'external_id': 'some_name2' }]) task.wait_till_done() assert task.status == "COMPLETE" - - task = dataset.create_data_rows([{"row_data": fp.name}]) + assert len(task.result) == 2 + assert task.has_errors() is False + results = task.result + row_data = [result["row_data"] for result in results] + assert len(row_data) == 2 + assert image_url in row_data + + +def test_data_row_bulk_creation_from_row_data_file(dataset, rand_gen, + local_image_file, image_url): + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', + new=1): # Force chunking + task = dataset.create_data_rows([{ + "row_data": local_image_file.name + }, { + "row_data": local_image_file.name + }]) task.wait_till_done() assert task.status == "COMPLETE" - - data_rows = list(dataset.data_rows()) - assert len(data_rows) == 5 - url = ({data_row.row_data for data_row in data_rows} - {image_url}).pop() - assert requests.get(url).content == data - - for dr in data_rows: - dr.delete() + assert len(task.result) == 2 + assert task.has_errors() is False + results = task.result + row_data = [result["row_data"] for result in results] + assert len(row_data) == 2 @pytest.mark.slow @@ -844,6 +899,7 @@ def test_create_data_rows_result(client, dataset, image_url): DataRow.external_id: "row1", }, ]) + task.wait_till_done() assert task.errors is None for result in task.result: client.get_data_row(result['id']) @@ -918,8 +974,14 @@ def test_data_row_bulk_creation_with_same_global_keys(dataset, sample_image, 'message'] == f"Duplicate global key: '{global_key_1}'" assert task.failed_data_rows[0]['failedDataRows'][0][ 'externalId'] == sample_image - assert task.created_data_rows[0]['externalId'] == sample_image - assert task.created_data_rows[0]['globalKey'] == global_key_1 + assert task.created_data_rows[0]['external_id'] == sample_image + assert task.created_data_rows[0]['global_key'] == global_key_1 + + assert len(task.errors) == 1 + assert task.has_errors() is True + + all_results = task.result + assert len(all_results) == 1 def test_data_row_delete_and_create_with_same_global_key( diff --git a/libs/labelbox/tests/integration/test_data_rows_upsert.py b/libs/labelbox/tests/integration/test_data_rows_upsert.py index c73ae5e5c..accde6dd7 100644 --- a/libs/labelbox/tests/integration/test_data_rows_upsert.py +++ b/libs/labelbox/tests/integration/test_data_rows_upsert.py @@ -208,9 +208,8 @@ def test_multiple_chunks(self, client, dataset, image_url): mocked_chunk_size = 3 with patch('labelbox.client.Client.upload_data', wraps=client.upload_data) as spy_some_function: - with patch( - 'labelbox.schema.dataset.Dataset._Dataset__upsert_chunk_size', - new=mocked_chunk_size): + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', + new=mocked_chunk_size): task = dataset.upsert_data_rows([{ 'row_data': image_url } for i in range(10)]) diff --git a/libs/labelbox/tests/integration/test_dataset.py b/libs/labelbox/tests/integration/test_dataset.py index de2f15820..82378bee1 100644 --- a/libs/labelbox/tests/integration/test_dataset.py +++ b/libs/labelbox/tests/integration/test_dataset.py @@ -3,6 +3,7 @@ from labelbox import Dataset from labelbox.exceptions import ResourceNotFoundError, MalformedQueryException, InvalidQueryError from labelbox.schema.dataset import MAX_DATAROW_PER_API_OPERATION +from 
labelbox.schema.internal.data_row_uploader import DataRowUploader def test_dataset(client, rand_gen): @@ -153,7 +154,10 @@ def test_create_descriptor_file(dataset): with mock.patch.object(dataset.client, 'upload_data', wraps=dataset.client.upload_data) as upload_data_spy: - dataset._create_descriptor_file(items=[{'row_data': 'some text...'}]) + DataRowUploader.create_descriptor_file(dataset.client, + items=[{ + 'row_data': 'some text...' + }]) upload_data_spy.assert_called() call_args, call_kwargs = upload_data_spy.call_args_list[0][ 0], upload_data_spy.call_args_list[0][1] @@ -162,12 +166,3 @@ def test_create_descriptor_file(dataset): 'content_type': 'application/json', 'filename': 'json_import.json' } - - -def test_max_dataset_datarow_upload(dataset, image_url, rand_gen): - external_id = str(rand_gen) - items = [dict(row_data=image_url, external_id=external_id) - ] * (MAX_DATAROW_PER_API_OPERATION + 1) - - with pytest.raises(MalformedQueryException): - dataset.create_data_rows(items) diff --git a/libs/labelbox/tests/integration/test_task.py b/libs/labelbox/tests/integration/test_task.py index 66b34d456..07309928a 100644 --- a/libs/labelbox/tests/integration/test_task.py +++ b/libs/labelbox/tests/integration/test_task.py @@ -58,14 +58,16 @@ def test_task_success_json(dataset, image_url, snapshot): 'test_task.test_task_success_json.json') assert len(task.result) + @pytest.mark.export_v1("export_v1 test remove later") def test_task_success_label_export(client, configured_project_with_label): project, _, _, _ = configured_project_with_label + # TODO: Move to export_v2 project.export_labels() user = client.get_user() task = None for task in user.created_tasks(): - if task.name != 'JSON Import': + if task.name != 'JSON Import' and task.type != 'adv-upsert-data-rows': break with pytest.raises(ValueError) as exc_info: diff --git a/libs/labelbox/tests/unit/test_data_row_upsert_data.py b/libs/labelbox/tests/unit/test_data_row_upsert_data.py new file mode 100644 index 000000000..9f6eb1400 --- /dev/null +++ b/libs/labelbox/tests/unit/test_data_row_upsert_data.py @@ -0,0 +1,66 @@ +import pytest +from labelbox.schema.internal.data_row_upsert_item import (DataRowUpsertItem) +from labelbox.schema.identifiable import UniqueId, GlobalKey +from labelbox.schema.asset_attachment import AttachmentType + + +@pytest.fixture +def data_row_create_items(): + dataset_id = 'test_dataset' + items = [ + { + "row_data": "http://my_site.com/photos/img_01.jpg", + "global_key": "global_key1", + "external_id": "ex_id1", + "attachments": [{ + "type": AttachmentType.RAW_TEXT, + "name": "att1", + "value": "test1" + }], + "metadata": [{ + "name": "tag", + "value": "tag value" + },] + }, + ] + return dataset_id, items + + +@pytest.fixture +def data_row_update_items(): + dataset_id = 'test_dataset' + items = [ + { + "key": GlobalKey("global_key1"), + "global_key": "global_key1_updated" + }, + { + "key": UniqueId('unique_id1'), + "external_id": "ex_id1_updated" + }, + ] + return dataset_id, items + + +def test_data_row_upsert_items(data_row_create_items, data_row_update_items): + dataset_id, create_items = data_row_create_items + dataset_id, update_items = data_row_update_items + items = create_items + update_items + result = DataRowUpsertItem.build(dataset_id, items, (UniqueId, GlobalKey)) + assert len(result) == len(items) + for item, res in zip(items, result): + assert res.payload == item + + +def test_data_row_create_items(data_row_create_items): + dataset_id, items = data_row_create_items + result = 
DataRowUpsertItem.build(dataset_id, items) + assert len(result) == len(items) + for item, res in zip(items, result): + assert res.payload == item + + +def test_data_row_create_items_not_updateable(data_row_update_items): + dataset_id, items = data_row_update_items + with pytest.raises(ValueError): + DataRowUpsertItem.build(dataset_id, items, ())
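
Reviewer note (not part of the patch): a minimal, untested usage sketch of the surface this diff introduces — `create_data_rows` / `upsert_data_rows` now returning a `DataUpsertTask` whose `result` and `errors` are fetched via the new paginated queries. It only strings together calls that appear in the diff and its tests; the API key, dataset id, URLs, and global keys are placeholders, not values from this change.

    # Usage sketch only — assumes a valid API key and an existing dataset id.
    from labelbox import Client
    from labelbox.schema.identifiable import UniqueId, GlobalKey

    client = Client(api_key="<LB_API_KEY>")          # placeholder credentials
    dataset = client.get_dataset("<DATASET_ID>")     # placeholder dataset id

    # Async bulk creation now returns a DataUpsertTask instead of a plain Task.
    task = dataset.create_data_rows(
        [
            {"row_data": "https://example.com/photos/img_01.jpg", "global_key": "img_01"},
            {"row_data": "https://example.com/photos/img_02.jpg", "global_key": "img_02"},
        ],
        file_upload_thread_count=2,
    )
    task.wait_till_done()

    if task.has_errors():
        # Errors are downloaded page by page (DOWNLOAD_RESULT_PAGE_SIZE) and
        # converted to the legacy failedDataRows shape.
        print(task.errors)
        raise SystemExit(1)

    # Each result row is a dict with id / external_id / global_key / row_data.
    created = task.result

    # Upsert: items with a `key` (UniqueId or GlobalKey) update existing rows;
    # items without a `key` create new rows in the same call.
    upsert_task = dataset.upsert_data_rows([
        {"key": GlobalKey("img_01"), "external_id": "img_01_updated"},
        {"key": UniqueId(created[0]["id"]), "external_id": "img_02_updated"},
        {"row_data": "https://example.com/photos/img_03.jpg"},
    ])
    upsert_task.wait_till_done()
    print(upsert_task.status, len(upsert_task.result or []))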