diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py
index 4f74e9b7f..9d02604d4 100644
--- a/libs/labelbox/src/labelbox/schema/dataset.py
+++ b/libs/labelbox/src/labelbox/schema/dataset.py
@@ -33,9 +33,10 @@ from labelbox.schema.user import User
 from labelbox.schema.iam_integration import IAMIntegration
 from labelbox.schema.internal.data_row_upsert_item import (DataRowUpsertItem)
-from labelbox.schema.internal.data_row_uploader import DataRowUploader
+import labelbox.schema.internal.data_row_uploader as data_row_uploader
+from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator
 from labelbox.schema.internal.datarow_upload_constants import (
-    MAX_DATAROW_PER_API_OPERATION, FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE)
+    FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE_BYTES)
 
 logger = logging.getLogger(__name__)
 
@@ -53,7 +54,6 @@ class Dataset(DbObject, Updateable, Deletable):
         created_by (Relationship): `ToOne` relationship to User
         organization (Relationship): `ToOne` relationship to Organization
     """
-    __upsert_chunk_size: Final = UPSERT_CHUNK_SIZE
 
     name = Field.String("name")
    description = Field.String("description")
@@ -240,10 +240,8 @@ def create_data_rows_sync(self, items) -> None:
                 f"Dataset.create_data_rows_sync() supports a max of {max_data_rows_supported} data rows."
                 " For larger imports use the async function Dataset.create_data_rows()"
             )
-        descriptor_url = DataRowUploader.create_descriptor_file(
-            self.client,
-            items,
-            max_attachments_per_data_row=max_attachments_per_data_row)
+        descriptor_url = DescriptorFileCreator(self.client).create_one(
+            items, max_attachments_per_data_row=max_attachments_per_data_row)
         dataset_param = "datasetId"
         url_param = "jsonUrl"
         query_str = """mutation AppendRowsToDatasetSyncPyApi($%s: ID!, $%s: String!){
@@ -264,7 +262,7 @@ def create_data_rows(self,
         Use this instead of `Dataset.create_data_rows_sync` uploads for batches that contain more than 1000 data rows.
 
         Args:
-            items (iterable of (dict or str)): See the docstring for `DataRowUploader.create_descriptor_file` for more information
+            items (iterable of (dict or str)): Local file paths (str) or dicts of DataRow fields; a dict must contain `row_data` at a minimum
 
         Returns:
             Task representing the data import on the server side.
The Task @@ -619,11 +617,11 @@ def _exec_upsert_data_rows( file_upload_thread_count: int = FILE_UPLOAD_THREAD_COUNT ) -> "DataUpsertTask": - manifest = DataRowUploader.upload_in_chunks( + manifest = data_row_uploader.upload_in_chunks( client=self.client, specs=specs, file_upload_thread_count=file_upload_thread_count, - upsert_chunk_size=UPSERT_CHUNK_SIZE) + max_chunk_size_bytes=UPSERT_CHUNK_SIZE_BYTES) data = json.dumps(manifest.dict()).encode("utf-8") manifest_uri = self.client.upload_data(data, diff --git a/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py b/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py index 9be4e2ffd..41b3d9752 100644 --- a/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py +++ b/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py @@ -1,287 +1,33 @@ -import json -import os from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Iterable, List +from typing import List -from labelbox.exceptions import InvalidQueryError -from labelbox.exceptions import InvalidAttributeError -from labelbox.exceptions import MalformedQueryException -from labelbox.orm.model import Entity -from labelbox.orm.model import Field -from labelbox.schema.embedding import EmbeddingVector -from labelbox.pydantic_compat import BaseModel -from labelbox.schema.internal.datarow_upload_constants import ( - MAX_DATAROW_PER_API_OPERATION, FILE_UPLOAD_THREAD_COUNT) +from labelbox import pydantic_compat from labelbox.schema.internal.data_row_upsert_item import DataRowUpsertItem +from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator -class UploadManifest(BaseModel): +class UploadManifest(pydantic_compat.BaseModel): source: str item_count: int chunk_uris: List[str] -class DataRowUploader: +SOURCE_SDK = "SDK" - @staticmethod - def create_descriptor_file(client, - items, - max_attachments_per_data_row=None, - is_upsert=False): - """ - This function is shared by `Dataset.create_data_rows`, `Dataset.create_data_rows_sync` and `Dataset.update_data_rows`. - It is used to prepare the input file. The user defined input is validated, processed, and json stringified. - Finally the json data is uploaded to gcs and a uri is returned. This uri can be passed as a parameter to a mutation that uploads data rows - Each element in `items` can be either a `str` or a `dict`. If - it is a `str`, then it is interpreted as a local file path. The file - is uploaded to Labelbox and a DataRow referencing it is created. +def upload_in_chunks(client, specs: List[DataRowUpsertItem], + file_upload_thread_count: int, + max_chunk_size_bytes: int) -> UploadManifest: + empty_specs = list(filter(lambda spec: spec.is_empty(), specs)) - If an item is a `dict`, then it could support one of the two following structures - 1. For static imagery, video, and text it should map `DataRow` field names to values. - At the minimum an `items` passed as a `dict` must contain a `row_data` key and value. - If the value for row_data is a local file path and the path exists, - then the local file will be uploaded to labelbox. + if empty_specs: + ids = list(map(lambda spec: spec.id.get("value"), empty_specs)) + raise ValueError(f"The following items have an empty payload: {ids}") - 2. 
For tiled imagery the dict must match the import structure specified in the link below - https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import + chunk_uris = DescriptorFileCreator(client).create( + specs, max_chunk_size_bytes=max_chunk_size_bytes) - >>> dataset.create_data_rows([ - >>> {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"}, - >>> {DataRow.row_data:"/path/to/file1.jpg"}, - >>> "path/to/file2.jpg", - >>> {DataRow.row_data: {"tileLayerUrl" : "http://", ...}} - >>> {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}} - >>> ]) - - For an example showing how to upload tiled data_rows see the following notebook: - https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb - - Args: - items (iterable of (dict or str)): See above for details. - max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine - if the user has provided too many attachments. - - Returns: - uri (string): A reference to the uploaded json data. - - Raises: - InvalidQueryError: If the `items` parameter does not conform to - the specification above or if the server did not accept the - DataRow creation request (unknown reason). - InvalidAttributeError: If there are fields in `items` not valid for - a DataRow. - ValueError: When the upload parameters are invalid - """ - file_upload_thread_count = FILE_UPLOAD_THREAD_COUNT - DataRow = Entity.DataRow - AssetAttachment = Entity.AssetAttachment - - def upload_if_necessary(item): - if is_upsert and 'row_data' not in item: - # When upserting, row_data is not required - return item - row_data = item['row_data'] - if isinstance(row_data, str) and os.path.exists(row_data): - item_url = client.upload_file(row_data) - item['row_data'] = item_url - if 'external_id' not in item: - # Default `external_id` to local file name - item['external_id'] = row_data - return item - - def validate_attachments(item): - attachments = item.get('attachments') - if attachments: - if isinstance(attachments, list): - if max_attachments_per_data_row and len( - attachments) > max_attachments_per_data_row: - raise ValueError( - f"Max attachments number of supported attachments per data row is {max_attachments_per_data_row}." - f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary." - ) - for attachment in attachments: - AssetAttachment.validate_attachment_json(attachment) - else: - raise ValueError( - f"Attachments must be a list. Found {type(attachments)}" - ) - return attachments - - def validate_embeddings(item): - embeddings = item.get("embeddings") - if embeddings: - item["embeddings"] = [ - EmbeddingVector(**e).to_gql() for e in embeddings - ] - - def validate_conversational_data(conversational_data: list) -> None: - """ - Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json - - Args: - conversational_data (list): list of dictionaries. - """ - - def check_message_keys(message): - accepted_message_keys = set([ - "messageId", "timestampUsec", "content", "user", "align", - "canLabel" - ]) - for key in message.keys(): - if not key in accepted_message_keys: - raise KeyError( - f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}" - ) - - if conversational_data and not isinstance(conversational_data, - list): - raise ValueError( - f"conversationalData must be a list. 
Found {type(conversational_data)}" - ) - - [check_message_keys(message) for message in conversational_data] - - def parse_metadata_fields(item): - metadata_fields = item.get('metadata_fields') - if metadata_fields: - mdo = client.get_data_row_metadata_ontology() - item['metadata_fields'] = mdo.parse_upsert_metadata( - metadata_fields) - - def format_row(item): - # Formats user input into a consistent dict structure - if isinstance(item, dict): - # Convert fields to strings - item = { - key.name if isinstance(key, Field) else key: value - for key, value in item.items() - } - elif isinstance(item, str): - # The main advantage of using a string over a dict is that the user is specifying - # that the file should exist locally. - # That info is lost after this section so we should check for it here. - if not os.path.exists(item): - raise ValueError(f"Filepath {item} does not exist.") - item = {"row_data": item, "external_id": item} - return item - - def validate_keys(item): - if not is_upsert and 'row_data' not in item: - raise InvalidQueryError( - "`row_data` missing when creating DataRow.") - - if isinstance(item.get('row_data'), - str) and item.get('row_data').startswith("s3:/"): - raise InvalidQueryError( - "row_data: s3 assets must start with 'https'.") - allowed_extra_fields = { - 'attachments', 'media_type', 'dataset_id', 'embeddings' - } - invalid_keys = set(item) - {f.name for f in DataRow.fields() - } - allowed_extra_fields - if invalid_keys: - raise InvalidAttributeError(DataRow, invalid_keys) - return item - - def format_legacy_conversational_data(item): - messages = item.pop("conversationalData") - version = item.pop("version", 1) - type = item.pop("type", "application/vnd.labelbox.conversational") - if "externalId" in item: - external_id = item.pop("externalId") - item["external_id"] = external_id - if "globalKey" in item: - global_key = item.pop("globalKey") - item["globalKey"] = global_key - validate_conversational_data(messages) - one_conversation = \ - { - "type": type, - "version": version, - "messages": messages - } - item["row_data"] = one_conversation - return item - - def convert_item(data_row_item): - if isinstance(data_row_item, DataRowUpsertItem): - item = data_row_item.payload - else: - item = data_row_item - - if "tileLayerUrl" in item: - validate_attachments(item) - return item - - if "conversationalData" in item: - format_legacy_conversational_data(item) - - # Convert all payload variations into the same dict format - item = format_row(item) - # Make sure required keys exist (and there are no extra keys) - validate_keys(item) - # Make sure attachments are valid - validate_attachments(item) - # Make sure embeddings are valid - validate_embeddings(item) - # Parse metadata fields if they exist - parse_metadata_fields(item) - # Upload any local file paths - item = upload_if_necessary(item) - - if isinstance(data_row_item, DataRowUpsertItem): - return {'id': data_row_item.id, 'payload': item} - else: - return item - - if not isinstance(items, Iterable): - raise ValueError( - f"Must pass an iterable to create_data_rows. Found {type(items)}" - ) - - if len(items) > MAX_DATAROW_PER_API_OPERATION: - raise MalformedQueryException( - f"Cannot create more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call." 
-            )
-
-        with ThreadPoolExecutor(file_upload_thread_count) as executor:
-            futures = [executor.submit(convert_item, item) for item in items]
-            items = [future.result() for future in as_completed(futures)]
-        # Prepare and upload the desciptor file
-        data = json.dumps(items)
-        return client.upload_data(data,
-                                  content_type="application/json",
-                                  filename="json_import.json")
-
-    @staticmethod
-    def upload_in_chunks(client, specs: List[DataRowUpsertItem],
-                         file_upload_thread_count: int,
-                         upsert_chunk_size: int) -> UploadManifest:
-        empty_specs = list(filter(lambda spec: spec.is_empty(), specs))
-
-        if empty_specs:
-            ids = list(map(lambda spec: spec.id.get("value"), empty_specs))
-            raise ValueError(
-                f"The following items have an empty payload: {ids}")
-
-        chunks = [
-            specs[i:i + upsert_chunk_size]
-            for i in range(0, len(specs), upsert_chunk_size)
-        ]
-
-        def _upload_chunk(chunk):
-            return DataRowUploader.create_descriptor_file(client,
-                                                          chunk,
-                                                          is_upsert=True)
-
-        with ThreadPoolExecutor(file_upload_thread_count) as executor:
-            futures = [
-                executor.submit(_upload_chunk, chunk) for chunk in chunks
-            ]
-            chunk_uris = [future.result() for future in as_completed(futures)]
-
-        return UploadManifest(source="SDK",
-                              item_count=len(specs),
-                              chunk_uris=chunk_uris)
+    return UploadManifest(source=SOURCE_SDK,
+                          item_count=len(specs),
+                          chunk_uris=chunk_uris)
diff --git a/libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py b/libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py
index f4c919095..7fb9fd058 100644
--- a/libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py
+++ b/libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py
@@ -1,4 +1,3 @@
-MAX_DATAROW_PER_API_OPERATION = 150_000
 FILE_UPLOAD_THREAD_COUNT = 20
-UPSERT_CHUNK_SIZE = 10_000
+UPSERT_CHUNK_SIZE_BYTES = 10_000_000
 DOWNLOAD_RESULT_PAGE_SIZE = 5_000
diff --git a/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py b/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py
new file mode 100644
index 000000000..4c7e95ee4
--- /dev/null
+++ b/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py
@@ -0,0 +1,315 @@
+import json
+import os
+import sys
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from typing import Iterable, List, Generator
+
+from labelbox.exceptions import InvalidQueryError
+from labelbox.exceptions import InvalidAttributeError
+from labelbox.exceptions import MalformedQueryException
+from labelbox.orm.model import Entity
+from labelbox.orm.model import Field
+from labelbox.schema.embedding import EmbeddingVector
+from labelbox.schema.internal.datarow_upload_constants import (
+    FILE_UPLOAD_THREAD_COUNT)
+from labelbox.schema.internal.data_row_upsert_item import DataRowUpsertItem
+
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from labelbox import Client
+
+
+class DescriptorFileCreator:
+    """
+    This class uploads a list of dicts as one or more JSON descriptor files and returns their URLs.
+    It will create multiple files if the size of the upload exceeds max_chunk_size_bytes,
+    upload the files to GCS in parallel, and return a list of URLs.
+
+    Args:
+        client (Client): The client object
+    """
+
+    def __init__(self, client: "Client"):
+        self.client = client
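+
+    # Illustrative usage (a sketch, not exercised by this change set; assumes an
+    # authenticated labelbox.Client and a prepared list of data row dicts):
+    #
+    #     creator = DescriptorFileCreator(client)
+    #     chunk_uris = creator.create(items, max_chunk_size_bytes=10_000_000)
+    #     single_uri = creator.create_one(items)
+    #
+    # `create` splits `items` into JSON descriptor files that fit the byte budget and
+    # uploads them concurrently; `create_one` uploads everything as a single file.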
+
+    def create(self,
+               items,
+               max_attachments_per_data_row=None,
+               max_chunk_size_bytes=None) -> List[str]:
+        """
+        Converts a list of items to JSON and uploads it to GCS in one or more files.
+
+        Multiple files are created if the serialized payload is larger than
+        max_chunk_size_bytes. The files are uploaded to GCS in parallel and a list
+        of URLs is returned.
+
+        Args:
+            items: The list to upload
+            max_attachments_per_data_row (int): The maximum number of attachments per data row
+            max_chunk_size_bytes (int): The maximum size of each uploaded file in bytes
+        """
+        is_upsert = True  # This class will only support upsert use cases
+        items = self._prepare_items_for_upload(items,
+                                               max_attachments_per_data_row,
+                                               is_upsert=is_upsert)
+        json_chunks = self._chunk_down_by_bytes(items, max_chunk_size_bytes)
+        with ThreadPoolExecutor(FILE_UPLOAD_THREAD_COUNT) as executor:
+            futures = [
+                executor.submit(self.client.upload_data, chunk,
+                                "application/json", "json_import.json")
+                for chunk in json_chunks
+            ]
+            return [future.result() for future in as_completed(futures)]
+
+    def create_one(self, items, max_attachments_per_data_row=None) -> str:
+        items = self._prepare_items_for_upload(items,
+                                               max_attachments_per_data_row)
+        # Prepare and upload the descriptor file
+        data = json.dumps(items)
+        return self.client.upload_data(data,
+                                       content_type="application/json",
+                                       filename="json_import.json")
+
+    def _prepare_items_for_upload(self,
+                                  items,
+                                  max_attachments_per_data_row=None,
+                                  is_upsert=False):
+        """
+        This function is used to prepare `items` for upload. The user-defined input is
+        validated and processed: local files referenced by `row_data` are uploaded to
+        Labelbox, and each item is normalized into a consistent dict structure that the
+        caller can JSON-stringify and upload as a descriptor file.
+
+        Each element in `items` can be either a `str` or a `dict`. If
+        it is a `str`, then it is interpreted as a local file path. The file
+        is uploaded to Labelbox and a DataRow referencing it is created.
+
+        If an item is a `dict`, it can have one of the two following structures:
+        1. For static imagery, video, and text it should map `DataRow` field names to values.
+           At the minimum, an item passed as a `dict` must contain a `row_data` key and value.
+           If the value for row_data is a local file path and the path exists,
+           then the local file will be uploaded to Labelbox.
+
+        2. For tiled imagery the dict must match the import structure specified in the link below
+           https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import
+
+        >>> dataset.create_data_rows([
+        >>>     {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"},
+        >>>     {DataRow.row_data:"/path/to/file1.jpg"},
+        >>>     "path/to/file2.jpg",
+        >>>     {DataRow.row_data: {"tileLayerUrl" : "http://", ...}}
+        >>>     {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}}
+        >>> ])
+
+        Args:
+            items (iterable of (dict or str)): See above for details.
+            max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine
+                if the user has provided too many attachments.
+
+        Returns:
+            items (list of dict): The validated and processed items, ready to be
+                serialized and uploaded.
+
+        Raises:
+            InvalidQueryError: If the `items` parameter does not conform to
+                the specification above or if the server did not accept the
+                DataRow creation request (unknown reason).
+            InvalidAttributeError: If there are fields in `items` not valid for
+                a DataRow.
+ ValueError: When the upload parameters are invalid + """ + file_upload_thread_count = FILE_UPLOAD_THREAD_COUNT + DataRow = Entity.DataRow + AssetAttachment = Entity.AssetAttachment + + def upload_if_necessary(item): + if is_upsert and 'row_data' not in item: + # When upserting, row_data is not required + return item + row_data = item['row_data'] + if isinstance(row_data, str) and os.path.exists(row_data): + item_url = self.client.upload_file(row_data) + item['row_data'] = item_url + if 'external_id' not in item: + # Default `external_id` to local file name + item['external_id'] = row_data + return item + + def validate_attachments(item): + attachments = item.get('attachments') + if attachments: + if isinstance(attachments, list): + if max_attachments_per_data_row and len( + attachments) > max_attachments_per_data_row: + raise ValueError( + f"Max attachments number of supported attachments per data row is {max_attachments_per_data_row}." + f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary." + ) + for attachment in attachments: + AssetAttachment.validate_attachment_json(attachment) + else: + raise ValueError( + f"Attachments must be a list. Found {type(attachments)}" + ) + return attachments + + def validate_embeddings(item): + embeddings = item.get("embeddings") + if embeddings: + item["embeddings"] = [ + EmbeddingVector(**e).to_gql() for e in embeddings + ] + + def validate_conversational_data(conversational_data: list) -> None: + """ + Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json + + Args: + conversational_data (list): list of dictionaries. + """ + + def check_message_keys(message): + accepted_message_keys = set([ + "messageId", "timestampUsec", "content", "user", "align", + "canLabel" + ]) + for key in message.keys(): + if not key in accepted_message_keys: + raise KeyError( + f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}" + ) + + if conversational_data and not isinstance(conversational_data, + list): + raise ValueError( + f"conversationalData must be a list. Found {type(conversational_data)}" + ) + + [check_message_keys(message) for message in conversational_data] + + def parse_metadata_fields(item): + metadata_fields = item.get('metadata_fields') + if metadata_fields: + mdo = self.client.get_data_row_metadata_ontology() + item['metadata_fields'] = mdo.parse_upsert_metadata( + metadata_fields) + + def format_row(item): + # Formats user input into a consistent dict structure + if isinstance(item, dict): + # Convert fields to strings + item = { + key.name if isinstance(key, Field) else key: value + for key, value in item.items() + } + elif isinstance(item, str): + # The main advantage of using a string over a dict is that the user is specifying + # that the file should exist locally. + # That info is lost after this section so we should check for it here. 
+ if not os.path.exists(item): + raise ValueError(f"Filepath {item} does not exist.") + item = {"row_data": item, "external_id": item} + return item + + def validate_keys(item): + if not is_upsert and 'row_data' not in item: + raise InvalidQueryError( + "`row_data` missing when creating DataRow.") + + if isinstance(item.get('row_data'), + str) and item.get('row_data').startswith("s3:/"): + raise InvalidQueryError( + "row_data: s3 assets must start with 'https'.") + allowed_extra_fields = { + 'attachments', 'media_type', 'dataset_id', 'embeddings' + } + invalid_keys = set(item) - {f.name for f in DataRow.fields() + } - allowed_extra_fields + if invalid_keys: + raise InvalidAttributeError(DataRow, invalid_keys) + return item + + def format_legacy_conversational_data(item): + messages = item.pop("conversationalData") + version = item.pop("version", 1) + type = item.pop("type", "application/vnd.labelbox.conversational") + if "externalId" in item: + external_id = item.pop("externalId") + item["external_id"] = external_id + if "globalKey" in item: + global_key = item.pop("globalKey") + item["globalKey"] = global_key + validate_conversational_data(messages) + one_conversation = \ + { + "type": type, + "version": version, + "messages": messages + } + item["row_data"] = one_conversation + return item + + def convert_item(data_row_item): + if isinstance(data_row_item, DataRowUpsertItem): + item = data_row_item.payload + else: + item = data_row_item + + if "tileLayerUrl" in item: + validate_attachments(item) + return item + + if "conversationalData" in item: + format_legacy_conversational_data(item) + + # Convert all payload variations into the same dict format + item = format_row(item) + # Make sure required keys exist (and there are no extra keys) + validate_keys(item) + # Make sure attachments are valid + validate_attachments(item) + # Make sure embeddings are valid + validate_embeddings(item) + # Parse metadata fields if they exist + parse_metadata_fields(item) + # Upload any local file paths + item = upload_if_necessary(item) + + if isinstance(data_row_item, DataRowUpsertItem): + return {'id': data_row_item.id, 'payload': item} + else: + return item + + if not isinstance(items, Iterable): + raise ValueError( + f"Must pass an iterable to create_data_rows. 
Found {type(items)}" + ) + + with ThreadPoolExecutor(file_upload_thread_count) as executor: + futures = [executor.submit(convert_item, item) for item in items] + items = [future.result() for future in as_completed(futures)] + + return items + + def _chunk_down_by_bytes(self, items: List[dict], + max_chunk_size: int) -> Generator[str, None, None]: + """ + Recursively chunks down a list of items into smaller lists until each list is less than or equal to max_chunk_size bytes + NOTE: if one data row is larger than max_chunk_size, it will be returned as one chunk + + Returns: + Generator[str, None, None]: A generator that yields a json string + """ + if not items: + return + data = json.dumps(items) + chunk_size = len(data.encode("utf-8")) + if chunk_size <= max_chunk_size: + yield data + return + + if len(items) == 1: + yield data + return + + half = len(items) // 2 + yield from self._chunk_down_by_bytes(items[:half], max_chunk_size) + yield from self._chunk_down_by_bytes(items[half:], max_chunk_size) diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py index 83dea3d1f..650472fab 100644 --- a/libs/labelbox/src/labelbox/schema/task.py +++ b/libs/labelbox/src/labelbox/schema/task.py @@ -11,9 +11,7 @@ from labelbox.pagination import PaginatedCollection from labelbox.schema.internal.datarow_upload_constants import ( - MAX_DATAROW_PER_API_OPERATION, - DOWNLOAD_RESULT_PAGE_SIZE, -) + DOWNLOAD_RESULT_PAGE_SIZE,) if TYPE_CHECKING: from labelbox import User @@ -233,7 +231,6 @@ class DataUpsertTask(Task): """ Task class for data row upsert operations """ - MAX_DOWNLOAD_SIZE: Final = MAX_DATAROW_PER_API_OPERATION def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/libs/labelbox/tests/integration/test_data_rows.py b/libs/labelbox/tests/integration/test_data_rows.py index 30a462f80..c3c8dc9cb 100644 --- a/libs/labelbox/tests/integration/test_data_rows.py +++ b/libs/labelbox/tests/integration/test_data_rows.py @@ -176,18 +176,18 @@ def test_data_row_bulk_creation(dataset, rand_gen, image_url): assert len(list(dataset.data_rows())) == 0 try: - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', - new=1): # Force chunking + payload = [ + { + DataRow.row_data: image_url + }, + { + "row_data": image_url + }, + ] + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', + new=300): # To make 2 chunks # Test creation using URL - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url - }, - { - "row_data": image_url - }, - ], - file_upload_thread_count=2) + task = dataset.create_data_rows(payload, file_upload_thread_count=2) task.wait_till_done() assert task.has_errors() is False assert task.status == "COMPLETE" @@ -226,8 +226,8 @@ def local_image_file(image_url) -> NamedTemporaryFile: def test_data_row_bulk_creation_from_file(dataset, local_image_file, image_url): - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', - new=1): # Force chunking + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', + new=500): # Force chunking task = dataset.create_data_rows( [local_image_file.name, local_image_file.name]) task.wait_till_done() @@ -241,8 +241,8 @@ def test_data_row_bulk_creation_from_file(dataset, local_image_file, image_url): def test_data_row_bulk_creation_from_row_data_file_external_id( dataset, local_image_file, image_url): - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', - new=1): # Force chunking + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', + new=500): # Force 
chunking task = dataset.create_data_rows([{ "row_data": local_image_file.name, 'external_id': 'some_name' @@ -262,8 +262,8 @@ def test_data_row_bulk_creation_from_row_data_file_external_id( def test_data_row_bulk_creation_from_row_data_file(dataset, rand_gen, local_image_file, image_url): - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', - new=1): # Force chunking + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', + new=500): # Force chunking task = dataset.create_data_rows([{ "row_data": local_image_file.name }, { diff --git a/libs/labelbox/tests/integration/test_data_rows_upsert.py b/libs/labelbox/tests/integration/test_data_rows_upsert.py index accde6dd7..67b7d69c7 100644 --- a/libs/labelbox/tests/integration/test_data_rows_upsert.py +++ b/libs/labelbox/tests/integration/test_data_rows_upsert.py @@ -205,17 +205,17 @@ def test_update_metadata_with_upsert(self, client, all_inclusive_data_row, assert dr.metadata_fields[1]['value'] == "train" def test_multiple_chunks(self, client, dataset, image_url): - mocked_chunk_size = 3 + mocked_chunk_size = 300 with patch('labelbox.client.Client.upload_data', wraps=client.upload_data) as spy_some_function: - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE_BYTES', new=mocked_chunk_size): task = dataset.upsert_data_rows([{ 'row_data': image_url } for i in range(10)]) task.wait_till_done() assert len(list(dataset.data_rows())) == 10 - assert spy_some_function.call_count == 5 # 4 chunks + manifest + assert spy_some_function.call_count == 11 # one per each data row + manifest first_call_args, _ = spy_some_function.call_args_list[0] first_chunk_content = first_call_args[0] @@ -231,7 +231,7 @@ def test_multiple_chunks(self, client, dataset, image_url): data = json.loads(manifest_content) assert data['source'] == "SDK" assert data['item_count'] == 10 - assert len(data['chunk_uris']) == 4 + assert len(data['chunk_uris']) == 10 def test_upsert_embedded_row_data(self, dataset): pdf_url = "https://lb-test-data.s3.us-west-1.amazonaws.com/document-samples/0801.3483.pdf" diff --git a/libs/labelbox/tests/integration/test_dataset.py b/libs/labelbox/tests/integration/test_dataset.py index 82378bee1..381c48fb2 100644 --- a/libs/labelbox/tests/integration/test_dataset.py +++ b/libs/labelbox/tests/integration/test_dataset.py @@ -1,9 +1,10 @@ import pytest import requests +from unittest.mock import MagicMock from labelbox import Dataset -from labelbox.exceptions import ResourceNotFoundError, MalformedQueryException, InvalidQueryError -from labelbox.schema.dataset import MAX_DATAROW_PER_API_OPERATION -from labelbox.schema.internal.data_row_uploader import DataRowUploader +from labelbox.exceptions import ResourceNotFoundError, InvalidQueryError + +from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator def test_dataset(client, rand_gen): @@ -151,13 +152,12 @@ def test_bulk_conversation(dataset, sample_bulk_conversation: list) -> None: def test_create_descriptor_file(dataset): import unittest.mock as mock - with mock.patch.object(dataset.client, - 'upload_data', - wraps=dataset.client.upload_data) as upload_data_spy: - DataRowUploader.create_descriptor_file(dataset.client, - items=[{ - 'row_data': 'some text...' - }]) + client = MagicMock() + with mock.patch.object(client, 'upload_data', + wraps=client.upload_data) as upload_data_spy: + DescriptorFileCreator(client).create_one(items=[{ + 'row_data': 'some text...' 
+ }]) upload_data_spy.assert_called() call_args, call_kwargs = upload_data_spy.call_args_list[0][ 0], upload_data_spy.call_args_list[0][1] diff --git a/libs/labelbox/tests/unit/test_unit_descriptor_file_creator.py b/libs/labelbox/tests/unit/test_unit_descriptor_file_creator.py new file mode 100644 index 000000000..630d80573 --- /dev/null +++ b/libs/labelbox/tests/unit/test_unit_descriptor_file_creator.py @@ -0,0 +1,56 @@ +import json + +from unittest.mock import MagicMock, Mock +import pytest + +from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator + + +def test_chunk_down_by_bytes_row_too_large(): + client = MagicMock() + + descriptor_file_creator = DescriptorFileCreator(client) + + chunk = [{"row_data": "a"}] + max_chunk_size_bytes = 1 + + res = descriptor_file_creator._chunk_down_by_bytes(chunk, + max_chunk_size_bytes) + assert [x for x in res] == [json.dumps([{"row_data": "a"}])] + + +def test_chunk_down_by_bytes_more_chunks(): + client = MagicMock() + + descriptor_file_creator = DescriptorFileCreator(client) + + chunk = [{"row_data": "a"}, {"row_data": "b"}] + max_chunk_size_bytes = len(json.dumps(chunk).encode("utf-8")) - 1 + + res = descriptor_file_creator._chunk_down_by_bytes(chunk, + max_chunk_size_bytes) + assert [x for x in res] == [ + json.dumps([{ + "row_data": "a" + }]), json.dumps([{ + "row_data": "b" + }]) + ] + + +def test_chunk_down_by_bytes_one_chunk(): + client = MagicMock() + + descriptor_file_creator = DescriptorFileCreator(client) + + chunk = [{"row_data": "a"}, {"row_data": "b"}] + max_chunk_size_bytes = len(json.dumps(chunk).encode("utf-8")) + + res = descriptor_file_creator._chunk_down_by_bytes(chunk, + max_chunk_size_bytes) + assert [x for x in res + ] == [json.dumps([{ + "row_data": "a" + }, { + "row_data": "b" + }])]
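+
+
+# Illustrative sketch (not part of the original change set): shows that the byte-based
+# chunking degrades to one chunk per item when the budget only fits a single serialized
+# item. It relies only on _chunk_down_by_bytes as defined above.
+def test_chunk_down_by_bytes_one_item_per_chunk():
+    client = MagicMock()
+
+    descriptor_file_creator = DescriptorFileCreator(client)
+
+    chunk = [{"row_data": c} for c in ("a", "b", "c", "d")]
+    # A budget that fits exactly one serialized item forces repeated halving
+    # down to single-item chunks.
+    max_chunk_size_bytes = len(json.dumps([chunk[0]]).encode("utf-8"))
+
+    res = descriptor_file_creator._chunk_down_by_bytes(chunk,
+                                                       max_chunk_size_bytes)
+    assert [x for x in res] == [json.dumps([item]) for item in chunk]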