From 65066b77142e930359be05b152af59015be30e04 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Wed, 15 May 2024 11:34:33 -0700 Subject: [PATCH 1/9] Refactor upsert code so that it can be reused for create Extract spec generation Extract data row upload logic Extract chunk generation and upload Update create data row Rename DatarowUploader --> DataRowUploader Reuse upsert backend for create_data_rows Add DataUpsertTask --- libs/labelbox/src/labelbox/client.py | 1 + libs/labelbox/src/labelbox/schema/dataset.py | 376 ++---------------- .../schema/internal/data_row_create_upsert.py | 66 +++ .../schema/internal/data_row_uploader.py | 294 ++++++++++++++ .../internal/datarow_upload_constants.py | 3 + libs/labelbox/src/labelbox/schema/task.py | 79 +++- .../tests/integration/test_data_rows.py | 36 +- .../tests/integration/test_dataset.py | 6 +- .../tests/unit/test_data_row_upsert_data.py | 64 +++ 9 files changed, 574 insertions(+), 351 deletions(-) create mode 100644 libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py create mode 100644 libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py create mode 100644 libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py create mode 100644 libs/labelbox/tests/unit/test_data_row_upsert_data.py diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 5ad843b89..a2fb09186 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -407,6 +407,7 @@ def upload_data(self, }), "map": (None, json.dumps({"1": ["variables.file"]})), } + response = requests.post( self.endpoint, headers={"authorization": "Bearer %s" % self.api_key}, diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py index 0bc5b74b1..033c337e7 100644 --- a/libs/labelbox/src/labelbox/schema/dataset.py +++ b/libs/labelbox/src/labelbox/schema/dataset.py @@ -29,27 +29,18 @@ from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params from labelbox.schema.export_task import ExportTask from labelbox.schema.identifiable import UniqueId, GlobalKey -from labelbox.schema.task import Task +from labelbox.schema.task import Task, DataUpsertTask from labelbox.schema.user import User from labelbox.schema.iam_integration import IAMIntegration +from labelbox.schema.internal.data_row_create_upsert import (DataRowItemBase, DataRowUpsertItem, DataRowCreateItem) +from labelbox.schema.internal.data_row_uploader import DataRowUploader +from labelbox.schema.internal.datarow_upload_constants import ( MAX_DATAROW_PER_API_OPERATION, FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE) logger = logging.getLogger(__name__) -MAX_DATAROW_PER_API_OPERATION = 150_000 - - -class DataRowUpsertItem(BaseModel): - id: dict - payload: dict - - def is_empty(self) -> bool: - """ - The payload is considered empty if it's actually empty or the only key is `dataset_id`. - :return: bool - """ - return (not self.payload or - len(self.payload.keys()) == 1 and "dataset_id" in self.payload) - class Dataset(DbObject, Updateable, Deletable): """ A Dataset is a collection of DataRows. 
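Taken together, this first patch routes `Dataset.create_data_rows` through the same spec-building and chunked-upload backend as `Dataset.upsert_data_rows`, and both now return a `DataUpsertTask`. The sketch below illustrates that public surface as it stands after the series; it is not part of the diff, the dataset name, key, and URLs are placeholders, and it assumes only the methods shown in these patches.

from labelbox import Client
from labelbox.schema.identifiable import GlobalKey

client = Client(api_key="<YOUR_API_KEY>")  # placeholder credential
dataset = client.create_dataset(name="upsert-refactor-demo")  # illustrative dataset name

# create_data_rows() now builds DataRowCreateItem specs and reuses the upsert backend
task = dataset.create_data_rows([
    {"row_data": "https://example.com/photos/img_01.jpg", "global_key": "img_01"},
])
task.wait_till_done()
print(task.status, task.result)  # DataUpsertTask exposes results instead of a bare Task

# upsert_data_rows() updates the same row when a key is supplied
update_task = dataset.upsert_data_rows([
    {"key": GlobalKey("img_01"), "external_id": "renamed_external_id"},
])
update_task.wait_till_done()
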
@@ -64,7 +58,7 @@ class Dataset(DbObject, Updateable, Deletable): created_by (Relationship): `ToOne` relationship to User organization (Relationship): `ToOne` relationship to Organization """ - __upsert_chunk_size: Final = 10_000 + __upsert_chunk_size: Final = UPSERT_CHUNK_SIZE name = Field.String("name") description = Field.String("description") @@ -251,8 +245,10 @@ def create_data_rows_sync(self, items) -> None: f"Dataset.create_data_rows_sync() supports a max of {max_data_rows_supported} data rows." " For larger imports use the async function Dataset.create_data_rows()" ) - descriptor_url = self._create_descriptor_file( - items, max_attachments_per_data_row=max_attachments_per_data_row) + descriptor_url = DataRowUploader.create_descriptor_file( + self.client, + items, + max_attachments_per_data_row=max_attachments_per_data_row) dataset_param = "datasetId" url_param = "jsonUrl" query_str = """mutation AppendRowsToDatasetSyncPyApi($%s: ID!, $%s: String!){ @@ -264,13 +260,16 @@ def create_data_rows_sync(self, items) -> None: url_param: descriptor_url }) - def create_data_rows(self, items) -> "Task": + def create_data_rows( + self, + items, + file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT) -> "Task": """ Asynchronously bulk upload data rows Use this instead of `Dataset.create_data_rows_sync` uploads for batches that contain more than 1000 data rows. Args: - items (iterable of (dict or str)): See the docstring for `Dataset._create_descriptor_file` for more information + items (iterable of (dict or str)): See the docstring for `DataRowUploader.create_descriptor_file` for more information Returns: Task representing the data import on the server side. The Task @@ -286,270 +285,9 @@ def create_data_rows(self, items) -> "Task": a DataRow. ValueError: When the upload parameters are invalid """ - descriptor_url = self._create_descriptor_file(items) - # Create data source - dataset_param = "datasetId" - url_param = "jsonUrl" - query_str = """mutation AppendRowsToDatasetPyApi($%s: ID!, $%s: String!){ - appendRowsToDataset(data:{datasetId: $%s, jsonFileUrl: $%s} - ){ taskId accepted errorMessage } } """ % (dataset_param, url_param, - dataset_param, url_param) - res = self.client.execute(query_str, { - dataset_param: self.uid, - url_param: descriptor_url - }) - res = res["appendRowsToDataset"] - if not res["accepted"]: - msg = res['errorMessage'] - raise InvalidQueryError( - f"Server did not accept DataRow creation request. {msg}") - - # Fetch and return the task. - task_id = res["taskId"] - user: User = self.client.get_user() - tasks: List[Task] = list( - user.created_tasks(where=Entity.Task.uid == task_id)) - # Cache user in a private variable as the relationship can't be - # resolved due to server-side limitations (see Task.created_by) - # for more info. - if len(tasks) != 1: - raise ResourceNotFoundError(Entity.Task, task_id) - task: Task = tasks[0] - task._user = user - return task - - def _create_descriptor_file(self, - items, - max_attachments_per_data_row=None, - is_upsert=False): - """ - This function is shared by both `Dataset.create_data_rows` and `Dataset.create_data_rows_sync` - to prepare the input file. The user defined input is validated, processed, and json stringified. - Finally the json data is uploaded to gcs and a uri is returned. This uri can be passed to - - - - Each element in `items` can be either a `str` or a `dict`. If - it is a `str`, then it is interpreted as a local file path. The file - is uploaded to Labelbox and a DataRow referencing it is created. 
- - If an item is a `dict`, then it could support one of the two following structures - 1. For static imagery, video, and text it should map `DataRow` field names to values. - At the minimum an `item` passed as a `dict` must contain a `row_data` key and value. - If the value for row_data is a local file path and the path exists, - then the local file will be uploaded to labelbox. - - 2. For tiled imagery the dict must match the import structure specified in the link below - https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import - - >>> dataset.create_data_rows([ - >>> {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"}, - >>> {DataRow.row_data:"/path/to/file1.jpg"}, - >>> "path/to/file2.jpg", - >>> {DataRow.row_data: {"tileLayerUrl" : "http://", ...}} - >>> {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}} - >>> ]) - - For an example showing how to upload tiled data_rows see the following notebook: - https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb - - Args: - items (iterable of (dict or str)): See above for details. - max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine - if the user has provided too many attachments. - - Returns: - uri (string): A reference to the uploaded json data. - - Raises: - InvalidQueryError: If the `items` parameter does not conform to - the specification above or if the server did not accept the - DataRow creation request (unknown reason). - InvalidAttributeError: If there are fields in `items` not valid for - a DataRow. - ValueError: When the upload parameters are invalid - """ - file_upload_thread_count = 20 - DataRow = Entity.DataRow - AssetAttachment = Entity.AssetAttachment - - def upload_if_necessary(item): - if is_upsert and 'row_data' not in item: - # When upserting, row_data is not required - return item - row_data = item['row_data'] - if isinstance(row_data, str) and os.path.exists(row_data): - item_url = self.client.upload_file(row_data) - item['row_data'] = item_url - if 'external_id' not in item: - # Default `external_id` to local file name - item['external_id'] = row_data - return item - - def validate_attachments(item): - attachments = item.get('attachments') - if attachments: - if isinstance(attachments, list): - if max_attachments_per_data_row and len( - attachments) > max_attachments_per_data_row: - raise ValueError( - f"Max attachments number of supported attachments per data row is {max_attachments_per_data_row}." - f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary." - ) - for attachment in attachments: - AssetAttachment.validate_attachment_json(attachment) - else: - raise ValueError( - f"Attachments must be a list. Found {type(attachments)}" - ) - return attachments - - def validate_embeddings(item): - embeddings = item.get("embeddings") - if embeddings: - item["embeddings"] = [ - EmbeddingVector(**e).to_gql() for e in embeddings - ] - - def validate_conversational_data(conversational_data: list) -> None: - """ - Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json - - Args: - conversational_data (list): list of dictionaries. 
- """ - - def check_message_keys(message): - accepted_message_keys = set([ - "messageId", "timestampUsec", "content", "user", "align", - "canLabel" - ]) - for key in message.keys(): - if not key in accepted_message_keys: - raise KeyError( - f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}" - ) - - if conversational_data and not isinstance(conversational_data, - list): - raise ValueError( - f"conversationalData must be a list. Found {type(conversational_data)}" - ) - - [check_message_keys(message) for message in conversational_data] - - def parse_metadata_fields(item): - metadata_fields = item.get('metadata_fields') - if metadata_fields: - mdo = self.client.get_data_row_metadata_ontology() - item['metadata_fields'] = mdo.parse_upsert_metadata( - metadata_fields) - - def format_row(item): - # Formats user input into a consistent dict structure - if isinstance(item, dict): - # Convert fields to strings - item = { - key.name if isinstance(key, Field) else key: value - for key, value in item.items() - } - elif isinstance(item, str): - # The main advantage of using a string over a dict is that the user is specifying - # that the file should exist locally. - # That info is lost after this section so we should check for it here. - if not os.path.exists(item): - raise ValueError(f"Filepath {item} does not exist.") - item = {"row_data": item, "external_id": item} - return item - - def validate_keys(item): - if not is_upsert and 'row_data' not in item: - raise InvalidQueryError( - "`row_data` missing when creating DataRow.") - - if isinstance(item.get('row_data'), - str) and item.get('row_data').startswith("s3:/"): - raise InvalidQueryError( - "row_data: s3 assets must start with 'https'.") - allowed_extra_fields = { - 'attachments', 'media_type', 'dataset_id', 'embeddings' - } - invalid_keys = set(item) - {f.name for f in DataRow.fields() - } - allowed_extra_fields - if invalid_keys: - raise InvalidAttributeError(DataRow, invalid_keys) - return item - - def formatLegacyConversationalData(item): - messages = item.pop("conversationalData") - version = item.pop("version", 1) - type = item.pop("type", "application/vnd.labelbox.conversational") - if "externalId" in item: - external_id = item.pop("externalId") - item["external_id"] = external_id - if "globalKey" in item: - global_key = item.pop("globalKey") - item["globalKey"] = global_key - validate_conversational_data(messages) - one_conversation = \ - { - "type": type, - "version": version, - "messages": messages - } - item["row_data"] = one_conversation - return item - - def convert_item(data_row_item): - if isinstance(data_row_item, DataRowUpsertItem): - item = data_row_item.payload - else: - item = data_row_item - - if "tileLayerUrl" in item: - validate_attachments(item) - return item - - if "conversationalData" in item: - formatLegacyConversationalData(item) - - # Convert all payload variations into the same dict format - item = format_row(item) - # Make sure required keys exist (and there are no extra keys) - validate_keys(item) - # Make sure attachments are valid - validate_attachments(item) - # Make sure embeddings are valid - validate_embeddings(item) - # Parse metadata fields if they exist - parse_metadata_fields(item) - # Upload any local file paths - item = upload_if_necessary(item) - - if isinstance(data_row_item, DataRowUpsertItem): - return {'id': data_row_item.id, 'payload': item} - else: - return item - - if not isinstance(items, Iterable): - raise ValueError( - f"Must pass an iterable to 
create_data_rows. Found {type(items)}" - ) - - if len(items) > MAX_DATAROW_PER_API_OPERATION: - raise MalformedQueryException( - f"Cannot create more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call." - ) - - with ThreadPoolExecutor(file_upload_thread_count) as executor: - futures = [executor.submit(convert_item, item) for item in items] - items = [future.result() for future in as_completed(futures)] - # Prepare and upload the desciptor file - data = json.dumps(items) - return self.client.upload_data(data, - content_type="application/json", - filename="json_import.json") + specs = DataRowCreateItem.build(self.uid, items) + return self._exec_upsert_data_rows(specs, file_upload_thread_count) def data_rows_for_external_id(self, external_id, @@ -809,7 +547,10 @@ def _export( is_streamable = res["isStreamable"] return Task.get_task(self.client, task_id), is_streamable - def upsert_data_rows(self, items, file_upload_thread_count=20) -> "Task": + def upsert_data_rows( + self, + items, + file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT) -> "Task": """ Upserts data rows in this dataset. When "key" is provided, and it references an existing data row, an update will be performed. When "key" is not provided a new data row will be created. @@ -845,35 +586,20 @@ def upsert_data_rows(self, items, file_upload_thread_count=20) -> "Task": f"Cannot upsert more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call." ) - specs = self._convert_items_to_upsert_format(items) - - empty_specs = list(filter(lambda spec: spec.is_empty(), specs)) - - if empty_specs: - ids = list(map(lambda spec: spec.id.get("value"), empty_specs)) - raise ValueError( - f"The following items have an empty payload: {ids}") - - chunks = [ - specs[i:i + self.__upsert_chunk_size] - for i in range(0, len(specs), self.__upsert_chunk_size) - ] + specs = DataRowUpsertItem.build(self.uid, items) + return self._exec_upsert_data_rows(specs, file_upload_thread_count) - def _upload_chunk(_chunk): - return self._create_descriptor_file(_chunk, is_upsert=True) - - with ThreadPoolExecutor(file_upload_thread_count) as executor: - futures = [ - executor.submit(_upload_chunk, chunk) for chunk in chunks - ] - chunk_uris = [future.result() for future in as_completed(futures)] + def _exec_upsert_data_rows( + self, + specs: List[DataRowItemBase], + file_upload_thread_count: int = FILE_UPLOAD_THREAD_COUNT) -> "Task": + manifest = DataRowUploader.upload_in_chunks( + client=self.client, + specs=specs, + file_upload_thread_count=file_upload_thread_count, + upsert_chunk_size=UPSERT_CHUNK_SIZE) - manifest = { - "source": "SDK", - "item_count": len(specs), - "chunk_uris": chunk_uris - } - data = json.dumps(manifest).encode("utf-8") + data = json.dumps(manifest.to_dict()).encode("utf-8") manifest_uri = self.client.upload_data(data, content_type="application/json", filename="manifest.json") @@ -888,32 +614,10 @@ def _upload_chunk(_chunk): res = self.client.execute(query_str, {"manifestUri": manifest_uri}) res = res["upsertDataRows"] - task = Task(self.client, res) + task = DataUpsertTask(self.client, res) task._user = self.client.get_user() return task - def _convert_items_to_upsert_format(self, _items): - _upsert_items: List[DataRowUpsertItem] = [] - for item in _items: - # enforce current dataset's id for all specs - item['dataset_id'] = self.uid - key = item.pop('key', None) - if not key: - key = {'type': 'AUTO', 'value': ''} - elif isinstance(key, UniqueId): - key = {'type': 'ID', 'value': key.key} - elif isinstance(key, GlobalKey): - key 
= {'type': 'GKEY', 'value': key.key} - else: - raise ValueError( - f"Key must be an instance of UniqueId or GlobalKey, got: {type(item['key']).__name__}" - ) - item = { - k: v for k, v in item.items() if v is not None - } # remove None values - _upsert_items.append(DataRowUpsertItem(payload=item, id=key)) - return _upsert_items - def add_iam_integration(self, iam_integration: Union[str, IAMIntegration]) -> IAMIntegration: """ Sets the IAM integration for the dataset. IAM integration is used to sign URLs for data row assets. @@ -1008,4 +712,4 @@ def remove_iam_integration(self) -> None: if not response: raise ResourceNotFoundError(Dataset, {"id": self.uid}) - \ No newline at end of file + diff --git a/libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py b/libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py new file mode 100644 index 000000000..70656dbf5 --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py @@ -0,0 +1,66 @@ +from abc import ABC, abstractmethod +from typing import List, Tuple, Optional + +from labelbox.schema.identifiable import UniqueId, GlobalKey +from labelbox.pydantic_compat import BaseModel + + +class DataRowItemBase(BaseModel, ABC): + id: dict + payload: dict + + @classmethod + @abstractmethod + def build( + cls, + dataset_id: str, + items: List[dict], + key_types: Optional[Tuple[type, ...]] = () + ) -> List["DataRowItemBase"]: + upload_items = [] + + for item in items: + # enforce current dataset's id for all specs + item['dataset_id'] = dataset_id + key = item.pop('key', None) + if not key: + key = {'type': 'AUTO', 'value': ''} + elif isinstance(key, key_types): + key = {'type': key.id_type.value, 'value': key.key} + else: + if not key_types: + raise ValueError( + f"Can not have a key for this item, got: {key}" + ) + raise ValueError( + f"Key must be an instance of {', '.join([t.__name__ for t in key_types])}, got: {type(item['key']).__name__}" + ) + item = { + k: v for k, v in item.items() if v is not None + } # remove None values + upload_items.append(cls(payload=item, id=key)) + return upload_items + + def is_empty(self) -> bool: + """ + The payload is considered empty if it's actually empty or the only key is `dataset_id`. 
+ :return: bool + """ + return (not self.payload or + len(self.payload.keys()) == 1 and "dataset_id" in self.payload) + + +class DataRowUpsertItem(DataRowItemBase): + + @classmethod + def build(cls, dataset_id: str, + items: List[dict]) -> List["DataRowUpsertItem"]: + return super().build(dataset_id, items, (UniqueId, GlobalKey)) + + +class DataRowCreateItem(DataRowItemBase): + + @classmethod + def build(cls, dataset_id: str, + items: List[dict]) -> List["DataRowCreateItem"]: + return super().build(dataset_id, items, ()) diff --git a/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py b/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py new file mode 100644 index 000000000..907016ebb --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py @@ -0,0 +1,294 @@ +import json +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +from typing import Iterable, List + +from labelbox.exceptions import InvalidQueryError +from labelbox.exceptions import InvalidAttributeError +from labelbox.exceptions import MalformedQueryException +from labelbox.orm.model import Entity +from labelbox.orm.model import Field +from labelbox.schema.embedding import EmbeddingVector +from labelbox.schema.internal.data_row_create_upsert import DataRowItemBase +from labelbox.schema.internal.datarow_upload_constants import MAX_DATAROW_PER_API_OPERATION + + +class UploadManifest: + + def __init__(self, source: str, item_count: int, chunk_uris: List[str]): + self.source = source + self.item_count = item_count + self.chunk_uris = chunk_uris + + def to_dict(self): + return { + "source": self.source, + "item_count": self.item_count, + "chunk_uris": self.chunk_uris + } + + +class DataRowUploader: + + @staticmethod + def create_descriptor_file(client, + items, + max_attachments_per_data_row=None, + is_upsert=False): + """ + This function is shared by `Dataset.create_data_rows`, `Dataset.create_data_rows_sync` and `Dataset.update_data_rows`. + It is used to prepare the input file. The user defined input is validated, processed, and json stringified. + Finally the json data is uploaded to gcs and a uri is returned. This uri can be passed to + + Each element in `items` can be either a `str` or a `dict`. If + it is a `str`, then it is interpreted as a local file path. The file + is uploaded to Labelbox and a DataRow referencing it is created. + + If an item is a `dict`, then it could support one of the two following structures + 1. For static imagery, video, and text it should map `DataRow` field names to values. + At the minimum an `item` passed as a `dict` must contain a `row_data` key and value. + If the value for row_data is a local file path and the path exists, + then the local file will be uploaded to labelbox. + + 2. For tiled imagery the dict must match the import structure specified in the link below + https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import + + >>> dataset.create_data_rows([ + >>> {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"}, + >>> {DataRow.row_data:"/path/to/file1.jpg"}, + >>> "path/to/file2.jpg", + >>> {DataRow.row_data: {"tileLayerUrl" : "http://", ...}} + >>> {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}} + >>> ]) + + For an example showing how to upload tiled data_rows see the following notebook: + https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb + + Args: + items (iterable of (dict or str)): See above for details. 
+ max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine + if the user has provided too many attachments. + + Returns: + uri (string): A reference to the uploaded json data. + + Raises: + InvalidQueryError: If the `items` parameter does not conform to + the specification above or if the server did not accept the + DataRow creation request (unknown reason). + InvalidAttributeError: If there are fields in `items` not valid for + a DataRow. + ValueError: When the upload parameters are invalid + """ + file_upload_thread_count = 20 + DataRow = Entity.DataRow + AssetAttachment = Entity.AssetAttachment + + def upload_if_necessary(item): + if is_upsert and 'row_data' not in item: + # When upserting, row_data is not required + return item + row_data = item['row_data'] + if isinstance(row_data, str) and os.path.exists(row_data): + item_url = client.upload_file(row_data) + item['row_data'] = item_url + if 'external_id' not in item: + # Default `external_id` to local file name + item['external_id'] = row_data + return item + + def validate_attachments(item): + attachments = item.get('attachments') + if attachments: + if isinstance(attachments, list): + if max_attachments_per_data_row and len( + attachments) > max_attachments_per_data_row: + raise ValueError( + f"Max attachments number of supported attachments per data row is {max_attachments_per_data_row}." + f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary." + ) + for attachment in attachments: + AssetAttachment.validate_attachment_json(attachment) + else: + raise ValueError( + f"Attachments must be a list. Found {type(attachments)}" + ) + return attachments + + def validate_embeddings(item): + embeddings = item.get("embeddings") + if embeddings: + item["embeddings"] = [ + EmbeddingVector(**e).to_gql() for e in embeddings + ] + + def validate_conversational_data(conversational_data: list) -> None: + """ + Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json + + Args: + conversational_data (list): list of dictionaries. + """ + + def check_message_keys(message): + accepted_message_keys = set([ + "messageId", "timestampUsec", "content", "user", "align", + "canLabel" + ]) + for key in message.keys(): + if not key in accepted_message_keys: + raise KeyError( + f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}" + ) + + if conversational_data and not isinstance(conversational_data, + list): + raise ValueError( + f"conversationalData must be a list. Found {type(conversational_data)}" + ) + + [check_message_keys(message) for message in conversational_data] + + def parse_metadata_fields(item): + metadata_fields = item.get('metadata_fields') + if metadata_fields: + mdo = client.get_data_row_metadata_ontology() + item['metadata_fields'] = mdo.parse_upsert_metadata( + metadata_fields) + + def format_row(item): + # Formats user input into a consistent dict structure + if isinstance(item, dict): + # Convert fields to strings + item = { + key.name if isinstance(key, Field) else key: value + for key, value in item.items() + } + elif isinstance(item, str): + # The main advantage of using a string over a dict is that the user is specifying + # that the file should exist locally. + # That info is lost after this section so we should check for it here. 
+ if not os.path.exists(item): + raise ValueError(f"Filepath {item} does not exist.") + item = {"row_data": item, "external_id": item} + return item + + def validate_keys(item): + if not is_upsert and 'row_data' not in item: + raise InvalidQueryError( + "`row_data` missing when creating DataRow.") + + if isinstance(item.get('row_data'), + str) and item.get('row_data').startswith("s3:/"): + raise InvalidQueryError( + "row_data: s3 assets must start with 'https'.") + allowed_extra_fields = { + 'attachments', 'media_type', 'dataset_id', 'embeddings' + } + invalid_keys = set(item) - {f.name for f in DataRow.fields() + } - allowed_extra_fields + if invalid_keys: + raise InvalidAttributeError(DataRow, invalid_keys) + return item + + def formatLegacyConversationalData(item): + messages = item.pop("conversationalData") + version = item.pop("version", 1) + type = item.pop("type", "application/vnd.labelbox.conversational") + if "externalId" in item: + external_id = item.pop("externalId") + item["external_id"] = external_id + if "globalKey" in item: + global_key = item.pop("globalKey") + item["globalKey"] = global_key + validate_conversational_data(messages) + one_conversation = \ + { + "type": type, + "version": version, + "messages": messages + } + item["row_data"] = one_conversation + return item + + def convert_item(data_row_item): + if isinstance(data_row_item, DataRowItemBase): + item = data_row_item.payload + else: + item = data_row_item + + if "tileLayerUrl" in item: + validate_attachments(item) + return item + + if "conversationalData" in item: + formatLegacyConversationalData(item) + + # Convert all payload variations into the same dict format + item = format_row(item) + # Make sure required keys exist (and there are no extra keys) + validate_keys(item) + # Make sure attachments are valid + validate_attachments(item) + # Make sure embeddings are valid + validate_embeddings(item) + # Parse metadata fields if they exist + parse_metadata_fields(item) + # Upload any local file paths + item = upload_if_necessary(item) + + if isinstance(data_row_item, DataRowItemBase): + return {'id': data_row_item.id, 'payload': item} + else: + return item + + if not isinstance(items, Iterable): + raise ValueError( + f"Must pass an iterable to create_data_rows. Found {type(items)}" + ) + + if len(items) > MAX_DATAROW_PER_API_OPERATION: + raise MalformedQueryException( + f"Cannot create more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call." 
+ ) + + with ThreadPoolExecutor(file_upload_thread_count) as executor: + futures = [executor.submit(convert_item, item) for item in items] + items = [future.result() for future in as_completed(futures)] + # Prepare and upload the desciptor file + data = json.dumps(items) + return client.upload_data(data, + content_type="application/json", + filename="json_import.json") + + @staticmethod + def upload_in_chunks(client, specs: List[DataRowItemBase], + file_upload_thread_count: int, + upsert_chunk_size: int) -> UploadManifest: + empty_specs = list(filter(lambda spec: spec.is_empty(), specs)) + + if empty_specs: + ids = list(map(lambda spec: spec.id.get("value"), empty_specs)) + raise ValueError( + f"The following items have an empty payload: {ids}") + + chunks = [ + specs[i:i + upsert_chunk_size] + for i in range(0, len(specs), upsert_chunk_size) + ] + + def _upload_chunk(_chunk): + return DataRowUploader.create_descriptor_file(client, + _chunk, + is_upsert=True) + + with ThreadPoolExecutor(file_upload_thread_count) as executor: + futures = [ + executor.submit(_upload_chunk, chunk) for chunk in chunks + ] + chunk_uris = [future.result() for future in as_completed(futures)] + + return UploadManifest(source="SDK", + item_count=len(specs), + chunk_uris=chunk_uris) diff --git a/libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py b/libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py new file mode 100644 index 000000000..8d8c2c1f6 --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py @@ -0,0 +1,3 @@ +MAX_DATAROW_PER_API_OPERATION = 150_000 +FILE_UPLOAD_THREAD_COUNT = 20 +UPSERT_CHUNK_SIZE = 10_000 diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py index 977f7d4fa..16ebeca8a 100644 --- a/libs/labelbox/src/labelbox/schema/task.py +++ b/libs/labelbox/src/labelbox/schema/task.py @@ -2,13 +2,16 @@ import logging import requests import time -from typing import TYPE_CHECKING, Callable, Optional, Dict, Any, List, Union +from typing import TYPE_CHECKING, Callable, Optional, Dict, Any, List, Union, Final from labelbox import parser from labelbox.exceptions import ResourceNotFoundError from labelbox.orm.db_object import DbObject from labelbox.orm.model import Field, Relationship, Entity +from labelbox.pagination import PaginatedCollection +from labelbox.schema.internal.datarow_upload_constants import MAX_DATAROW_PER_API_OPERATION + if TYPE_CHECKING: from labelbox import User @@ -214,3 +217,77 @@ def get_task(client, task_id): task: Task = tasks[0] task._user = user return task + + +class DataUpsertTask(Task): + __max_donwload_size: Final = MAX_DATAROW_PER_API_OPERATION + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._user = None + + @property + def result(self) -> Union[List[Dict[str, Any]]]: + if self.status == "FAILED": + raise ValueError(f"Job failed. 
Errors : {self.errors}") + return self._result_as_list() + + @property + def created_data_rows(self) -> Optional[Dict[str, Any]]: + return self.result + + @property + def result_all(self) -> PaginatedCollection: + return self._download_result_paginated() + + def _download_result_paginated(self) -> PaginatedCollection: + page_size = 5000 # hardcode to avoid overloading the server + from_cursor = None + + query_str = """query SuccessesfulDataRowImportsPyApi($taskId: ID!, $first: Int, $from: String) { + successesfulDataRowImports(data: { taskId: $taskId, first: $first, from: $from}) + { + nodes { + id + externalId + globalKey + rowData + } + after + total + } + } + """ + + params = { + 'taskId': self.uid, + 'first': page_size, + 'from': from_cursor, + } + + return PaginatedCollection( + client=self.client, + query=query_str, + params=params, + dereferencing=['successesfulDataRowImports', 'nodes'], + obj_class=lambda _, data_row: { + 'id': data_row['id'], + 'external_id': data_row.get('externalId'), + 'row_data': data_row['rowData'], + 'global_key': data_row.get('globalKey'), + }, + cursor_path=['successesfulDataRowImports', 'after'], + ) + + def _result_as_list(self) -> List[Dict[str, Any]]: + total_downloaded = 0 + results = [] + data = self._download_result_paginated() + + for row in data: + results.append(row) + total_downloaded += 1 + if total_downloaded >= self.__max_donwload_size: + break + + return results diff --git a/libs/labelbox/tests/integration/test_data_rows.py b/libs/labelbox/tests/integration/test_data_rows.py index 672afe85d..427a8da17 100644 --- a/libs/labelbox/tests/integration/test_data_rows.py +++ b/libs/labelbox/tests/integration/test_data_rows.py @@ -2,12 +2,12 @@ import uuid from datetime import datetime import json +import requests -from labelbox.schema.media_type import MediaType - +from unittest.mock import patch import pytest -import requests +from labelbox.schema.media_type import MediaType from labelbox import DataRow, AssetAttachment from labelbox.exceptions import MalformedQueryException from labelbox.schema.task import Task @@ -173,19 +173,29 @@ def test_data_row_bulk_creation(dataset, rand_gen, image_url): client = dataset.client assert len(list(dataset.data_rows())) == 0 - # Test creation using URL - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url - }, - { - "row_data": image_url - }, - ]) - assert task in client.get_user().created_tasks() + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', + new=1): # Force chunking + # Test creation using URL + task = dataset.create_data_rows([ + { + DataRow.row_data: image_url + }, + { + "row_data": image_url + }, + ]) task.wait_till_done() + assert task.has_errors() is False assert task.status == "COMPLETE" + results = task.result + assert len(results) == 2 + row_data = [result["row_data"] for result in results] + assert row_data == [image_url, image_url] + results_all = task.result_all + row_data = [result["row_data"] for result in results_all] + assert row_data == [image_url, image_url] + data_rows = list(dataset.data_rows()) assert len(data_rows) == 2 assert {data_row.row_data for data_row in data_rows} == {image_url} diff --git a/libs/labelbox/tests/integration/test_dataset.py b/libs/labelbox/tests/integration/test_dataset.py index de2f15820..8cd83a9bd 100644 --- a/libs/labelbox/tests/integration/test_dataset.py +++ b/libs/labelbox/tests/integration/test_dataset.py @@ -3,6 +3,7 @@ from labelbox import Dataset from labelbox.exceptions import ResourceNotFoundError, MalformedQueryException, 
InvalidQueryError from labelbox.schema.dataset import MAX_DATAROW_PER_API_OPERATION +from labelbox.schema.internal.datarow_uploader import DataRowUploader def test_dataset(client, rand_gen): @@ -153,7 +154,10 @@ def test_create_descriptor_file(dataset): with mock.patch.object(dataset.client, 'upload_data', wraps=dataset.client.upload_data) as upload_data_spy: - dataset._create_descriptor_file(items=[{'row_data': 'some text...'}]) + DataRowUploader.create_descriptor_file(dataset.client, + items=[{ + 'row_data': 'some text...' + }]) upload_data_spy.assert_called() call_args, call_kwargs = upload_data_spy.call_args_list[0][ 0], upload_data_spy.call_args_list[0][1] diff --git a/libs/labelbox/tests/unit/test_data_row_upsert_data.py b/libs/labelbox/tests/unit/test_data_row_upsert_data.py new file mode 100644 index 000000000..93683cfca --- /dev/null +++ b/libs/labelbox/tests/unit/test_data_row_upsert_data.py @@ -0,0 +1,64 @@ +import pytest +from labelbox.schema.internal.data_row_create_upsert import (DataRowUpsertItem, + DataRowCreateItem) +from labelbox.schema.identifiable import UniqueId, GlobalKey +from labelbox.schema.asset_attachment import AttachmentType + + +@pytest.fixture +def data_row_create_items(): + dataset_id = 'test_dataset' + items = [ + { + "row_data": "http://my_site.com/photos/img_01.jpg", + "global_key": "global_key1", + "external_id": "ex_id1", + "attachments": [ + {"type": AttachmentType.RAW_TEXT, "name": "att1", "value": "test1"} + ], + "metadata": [ + {"name": "tag", "value": "tag value"}, + ] + }, + ] + return dataset_id, items + + +@pytest.fixture +def data_row_update_items(): + dataset_id = 'test_dataset' + items = [ + { + "key": GlobalKey("global_key1"), + "global_key": "global_key1_updated" + }, + { + "key": UniqueId('unique_id1'), + "external_id": "ex_id1_updated" + }, + ] + return dataset_id, items + + +def test_data_row_upsert_items(data_row_create_items, data_row_update_items): + dataset_id, create_items = data_row_create_items + dataset_id, update_items = data_row_update_items + items = create_items + update_items + result = DataRowUpsertItem.build(dataset_id, items) + assert len(result) == len(items) + for item, res in zip(items, result): + assert res.payload == item + + +def test_data_row_create_items(data_row_create_items): + dataset_id, items = data_row_create_items + result = DataRowCreateItem.build(dataset_id, items) + assert len(result) == len(items) + for item, res in zip(items, result): + assert res.payload == item + + +def test_data_row_create_items_not_updateable(data_row_update_items): + dataset_id, items = data_row_update_items + with pytest.raises(ValueError): + DataRowCreateItem.build(dataset_id, items) From 3d9e52843c4a592ef0e9ecd61286390c0c793a3c Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Wed, 22 May 2024 11:01:26 -0700 Subject: [PATCH 2/9] Add support for task errors --- libs/labelbox/src/labelbox/schema/task.py | 89 +++++++++++++++++++++-- 1 file changed, 81 insertions(+), 8 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py index 16ebeca8a..12158fed1 100644 --- a/libs/labelbox/src/labelbox/schema/task.py +++ b/libs/labelbox/src/labelbox/schema/task.py @@ -230,18 +230,30 @@ def __init__(self, *args, **kwargs): def result(self) -> Union[List[Dict[str, Any]]]: if self.status == "FAILED": raise ValueError(f"Job failed. 
Errors : {self.errors}") - return self._result_as_list() + return self._results_as_list() + + @property + def errors(self) -> Optional[Dict[str, Any]]: + return self._errors_as_list() @property def created_data_rows(self) -> Optional[Dict[str, Any]]: return self.result + @property + def failed_data_rows(self) -> Optional[Dict[str, Any]]: + return self.errors + @property def result_all(self) -> PaginatedCollection: - return self._download_result_paginated() + return self._download_results_paginated() - def _download_result_paginated(self) -> PaginatedCollection: - page_size = 5000 # hardcode to avoid overloading the server + @property + def errors_all(self) -> PaginatedCollection: + return self._download_errors_paginated() + + def _download_results_paginated(self) -> PaginatedCollection: + page_size = 900 # hardcode to avoid overloading the server from_cursor = None query_str = """query SuccessesfulDataRowImportsPyApi($taskId: ID!, $first: Int, $from: String) { @@ -271,18 +283,66 @@ def _download_result_paginated(self) -> PaginatedCollection: params=params, dereferencing=['successesfulDataRowImports', 'nodes'], obj_class=lambda _, data_row: { - 'id': data_row['id'], + 'id': data_row.get('id'), 'external_id': data_row.get('externalId'), - 'row_data': data_row['rowData'], + 'row_data': data_row.get('rowData'), 'global_key': data_row.get('globalKey'), }, cursor_path=['successesfulDataRowImports', 'after'], ) - def _result_as_list(self) -> List[Dict[str, Any]]: + def _download_errors_paginated(self) -> PaginatedCollection: + page_size = 5000 # hardcode to avoid overloading the server + from_cursor = None + + query_str = """query FailedDataRowImportsPyApi($taskId: ID!, $first: Int, $from: String) { + failedDataRowImports(data: { taskId: $taskId, first: $first, from: $from}) + { + after + total + results { + message + spec { + externalId + globalKey + rowData + } + } + } + } + """ + + params = { + 'taskId': self.uid, + 'first': page_size, + 'from': from_cursor, + } + + return PaginatedCollection( + client=self.client, + query=query_str, + params=params, + dereferencing=['failedDataRowImports', 'results'], + obj_class=lambda _, data_row: { + 'error': + data_row.get('message'), + 'external_id': + data_row.get('spec').get('externalId') + if data_row.get('spec') else None, + 'row_data': + data_row.get('spec').get('rowData') + if data_row.get('spec') else None, + 'global_key': + data_row.get('spec').get('globalKey') + if data_row.get('spec') else None, + }, + cursor_path=['failedDataRowImports', 'after'], + ) + + def _results_as_list(self) -> List[Dict[str, Any]]: total_downloaded = 0 results = [] - data = self._download_result_paginated() + data = self._download_results_paginated() for row in data: results.append(row) @@ -291,3 +351,16 @@ def _result_as_list(self) -> List[Dict[str, Any]]: break return results + + def _errors_as_list(self) -> List[Dict[str, Any]]: + total_downloaded = 0 + errors = [] + data = self._download_errors_paginated() + + for row in data: + errors.append(row) + total_downloaded += 1 + if total_downloaded >= self.__max_donwload_size: + break + + return errors From ce942d837616fe513b47d124a59b9b171f59d375 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Thu, 23 May 2024 09:19:57 -0700 Subject: [PATCH 3/9] Support for uploading from files --- libs/labelbox/src/labelbox/schema/dataset.py | 27 +++- .../schema/internal/data_row_uploader.py | 2 +- .../tests/integration/test_data_rows.py | 133 ++++++++++++------ 3 files changed, 116 insertions(+), 46 deletions(-) diff --git 
a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py index 033c337e7..b1e835244 100644 --- a/libs/labelbox/src/labelbox/schema/dataset.py +++ b/libs/labelbox/src/labelbox/schema/dataset.py @@ -284,11 +284,35 @@ def create_data_rows( InvalidAttributeError: If there are fields in `items` not valid for a DataRow. ValueError: When the upload parameters are invalid - """ + NOTE dicts and strings items can not be mixed in the same call. It is a responsibility of the caller to ensure that all items are of the same type. + """ + if isinstance(items[0], str): + items = self._build_from_local_paths(items) # Assume list of file paths specs = DataRowCreateItem.build(self.uid, items) return self._exec_upsert_data_rows(specs, file_upload_thread_count) + def _build_from_local_paths( + self, + items: List[str], + file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT) -> List[dict]: + uploaded_items = [] + + def upload_file(item): + item_url = self.client.upload_file(item) + return {'row_data': item_url, 'external_id': item} + + with ThreadPoolExecutor(file_upload_thread_count) as executor: + futures = [ + executor.submit(upload_file, item) + for item in items + if isinstance(item, str) and os.path.exists(item) + ] + more_items = [future.result() for future in as_completed(futures)] + uploaded_items.extend(more_items) + + return uploaded_items + def data_rows_for_external_id(self, external_id, limit=10) -> List["DataRow"]: @@ -593,6 +617,7 @@ def _exec_upsert_data_rows( self, specs: List[DataRowItemBase], file_upload_thread_count: int = FILE_UPLOAD_THREAD_COUNT) -> "Task": + manifest = DataRowUploader.upload_in_chunks( client=self.client, specs=specs, diff --git a/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py b/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py index 907016ebb..57a6208b2 100644 --- a/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py +++ b/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py @@ -2,7 +2,7 @@ import os from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Iterable, List +from typing import Iterable, List, Union from labelbox.exceptions import InvalidQueryError from labelbox.exceptions import InvalidAttributeError diff --git a/libs/labelbox/tests/integration/test_data_rows.py b/libs/labelbox/tests/integration/test_data_rows.py index 427a8da17..f02f3816f 100644 --- a/libs/labelbox/tests/integration/test_data_rows.py +++ b/libs/labelbox/tests/integration/test_data_rows.py @@ -3,6 +3,7 @@ from datetime import datetime import json import requests +import os from unittest.mock import patch import pytest @@ -171,66 +172,110 @@ def test_lookup_data_rows(client, dataset): def test_data_row_bulk_creation(dataset, rand_gen, image_url): client = dataset.client + data_rows = [] assert len(list(dataset.data_rows())) == 0 - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', - new=1): # Force chunking - # Test creation using URL - task = dataset.create_data_rows([ - { - DataRow.row_data: image_url - }, - { - "row_data": image_url - }, - ]) - task.wait_till_done() - assert task.has_errors() is False - assert task.status == "COMPLETE" + try: + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', + new=1): # Force chunking + # Test creation using URL + task = dataset.create_data_rows([ + { + DataRow.row_data: image_url + }, + { + "row_data": image_url + }, + ]) + task.wait_till_done() + assert task.has_errors() is False + assert task.status == "COMPLETE" - results = 
task.result - assert len(results) == 2 - row_data = [result["row_data"] for result in results] - assert row_data == [image_url, image_url] - results_all = task.result_all - row_data = [result["row_data"] for result in results_all] - assert row_data == [image_url, image_url] + results = task.result + assert len(results) == 2 + row_data = [result["row_data"] for result in results] + assert row_data == [image_url, image_url] + results_all = task.result_all + row_data = [result["row_data"] for result in results_all] + assert row_data == [image_url, image_url] - data_rows = list(dataset.data_rows()) - assert len(data_rows) == 2 - assert {data_row.row_data for data_row in data_rows} == {image_url} - assert {data_row.global_key for data_row in data_rows} == {None} + data_rows = list(dataset.data_rows()) + assert len(data_rows) == 2 + assert {data_row.row_data for data_row in data_rows} == {image_url} + assert {data_row.global_key for data_row in data_rows} == {None} - data_rows = list(dataset.data_rows(from_cursor=data_rows[0].uid)) - assert len(data_rows) == 1 + data_rows = list(dataset.data_rows(from_cursor=data_rows[0].uid)) + assert len(data_rows) == 1 - # Test creation using file name - with NamedTemporaryFile() as fp: - data = rand_gen(str).encode() - fp.write(data) - fp.flush() - task = dataset.create_data_rows([fp.name]) + finally: + for dr in data_rows: + dr.delete() + + +@pytest.fixture +def local_image_file(image_url) -> NamedTemporaryFile: + response = requests.get(image_url, stream=True) + response.raise_for_status() + + with NamedTemporaryFile(delete=False) as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + yield f # Return the path to the temp file + + os.remove(f.name) + + +def test_data_row_bulk_creation_from_file(dataset, local_image_file, image_url): + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', + new=1): # Force chunking + task = dataset.create_data_rows( + [local_image_file.name, local_image_file.name]) task.wait_till_done() assert task.status == "COMPLETE" + assert len(task.result) == 2 + assert task.has_errors() is False + results = [r for r in task.result_all] + row_data = [result["row_data"] for result in results] + assert row_data == [image_url, image_url] + +def test_data_row_bulk_creation_from_row_data_file_external_id( + dataset, local_image_file, image_url): + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', + new=1): # Force chunking task = dataset.create_data_rows([{ - "row_data": fp.name, + "row_data": local_image_file.name, 'external_id': 'some_name' + }, { + "row_data": image_url, + 'external_id': 'some_name2' }]) - task.wait_till_done() assert task.status == "COMPLETE" + assert len(task.result) == 2 + assert task.has_errors() is False + results = [r for r in task.result_all] + row_data = [result["row_data"] for result in results] + assert row_data == [image_url, image_url] + - task = dataset.create_data_rows([{"row_data": fp.name}]) +def test_data_row_bulk_creation_from_row_data_file(dataset, rand_gen, + local_image_file, image_url): + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', + new=1): # Force chunking + task = dataset.create_data_rows([{ + "row_data": local_image_file.name + }, { + "row_data": local_image_file.name + }]) task.wait_till_done() assert task.status == "COMPLETE" - - data_rows = list(dataset.data_rows()) - assert len(data_rows) == 5 - url = ({data_row.row_data for data_row in data_rows} - {image_url}).pop() - assert requests.get(url).content == data - - for dr in 
data_rows: - dr.delete() + assert len(task.result) == 2 + assert task.has_errors() is False + results = [r for r in task.result_all] + row_data = [result["row_data"] for result in results] + assert row_data == [image_url, image_url] @pytest.mark.slow From cefbedeeff68f07b0f0c16123043366d2bcf528e Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Thu, 23 May 2024 09:20:51 -0700 Subject: [PATCH 4/9] Fixing tests --- libs/labelbox/src/labelbox/schema/dataset.py | 11 ++- .../schema/internal/data_row_create_upsert.py | 21 ++++-- .../internal/datarow_upload_constants.py | 1 + libs/labelbox/src/labelbox/schema/task.py | 71 +++++++++++++------ .../tests/integration/test_data_rows.py | 21 ++++-- .../integration/test_data_rows_upsert.py | 5 +- .../tests/integration/test_dataset.py | 11 +-- libs/labelbox/tests/integration/test_task.py | 5 +- 8 files changed, 94 insertions(+), 52 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py index b1e835244..9e230340d 100644 --- a/libs/labelbox/src/labelbox/schema/dataset.py +++ b/libs/labelbox/src/labelbox/schema/dataset.py @@ -287,9 +287,14 @@ def create_data_rows( NOTE dicts and strings items can not be mixed in the same call. It is a responsibility of the caller to ensure that all items are of the same type. """ - if isinstance(items[0], str): - items = self._build_from_local_paths(items) # Assume list of file paths - specs = DataRowCreateItem.build(self.uid, items) + string_items = [item for item in items if isinstance(item, str)] + dict_items = [item for item in items if isinstance(item, dict)] + dict_string_items = [] + + if len(string_items) > 0: + dict_string_items = self._build_from_local_paths(string_items) + specs = DataRowCreateItem.build(self.uid, + dict_items + dict_string_items) return self._exec_upsert_data_rows(specs, file_upload_thread_count) def _build_from_local_paths( diff --git a/libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py b/libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py index 70656dbf5..fbbc02b70 100644 --- a/libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py +++ b/libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py @@ -25,13 +25,12 @@ def build( key = item.pop('key', None) if not key: key = {'type': 'AUTO', 'value': ''} - elif isinstance(key, key_types): + elif isinstance(key, key_types): # type: ignore key = {'type': key.id_type.value, 'value': key.key} else: if not key_types: raise ValueError( - f"Can not have a key for this item, got: {key}" - ) + f"Can not have a key for this item, got: {key}") raise ValueError( f"Key must be an instance of {', '.join([t.__name__ for t in key_types])}, got: {type(item['key']).__name__}" ) @@ -53,14 +52,22 @@ def is_empty(self) -> bool: class DataRowUpsertItem(DataRowItemBase): @classmethod - def build(cls, dataset_id: str, - items: List[dict]) -> List["DataRowUpsertItem"]: + def build( + cls, + dataset_id: str, + items: List[dict], + key_types: Optional[Tuple[type, ...]] = () + ) -> List["DataRowItemBase"]: return super().build(dataset_id, items, (UniqueId, GlobalKey)) class DataRowCreateItem(DataRowItemBase): @classmethod - def build(cls, dataset_id: str, - items: List[dict]) -> List["DataRowCreateItem"]: + def build( + cls, + dataset_id: str, + items: List[dict], + key_types: Optional[Tuple[type, ...]] = () + ) -> List["DataRowItemBase"]: return super().build(dataset_id, items, ()) diff --git 
a/libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py b/libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py index 8d8c2c1f6..f4c919095 100644 --- a/libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py +++ b/libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py @@ -1,3 +1,4 @@ MAX_DATAROW_PER_API_OPERATION = 150_000 FILE_UPLOAD_THREAD_COUNT = 20 UPSERT_CHUNK_SIZE = 10_000 +DOWNLOAD_RESULT_PAGE_SIZE = 5_000 diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py index 12158fed1..38859e262 100644 --- a/libs/labelbox/src/labelbox/schema/task.py +++ b/libs/labelbox/src/labelbox/schema/task.py @@ -10,7 +10,10 @@ from labelbox.orm.model import Field, Relationship, Entity from labelbox.pagination import PaginatedCollection -from labelbox.schema.internal.datarow_upload_constants import MAX_DATAROW_PER_API_OPERATION +from labelbox.schema.internal.datarow_upload_constants import ( + MAX_DATAROW_PER_API_OPERATION, + DOWNLOAD_RESULT_PAGE_SIZE, +) if TYPE_CHECKING: from labelbox import User @@ -52,6 +55,10 @@ class Task(DbObject): created_by = Relationship.ToOne("User", False, "created_by") organization = Relationship.ToOne("Organization") + def __eq__(self, task): + return isinstance( + task, Task) and task.uid == self.uid and task.type == self.type + # Import and upsert have several instances of special casing def is_creation_task(self) -> bool: return self.name == 'JSON Import' or self.type == 'adv-upsert-data-rows' @@ -227,21 +234,23 @@ def __init__(self, *args, **kwargs): self._user = None @property - def result(self) -> Union[List[Dict[str, Any]]]: + def result(self) -> Optional[List[Dict[str, Any]]]: # type: ignore if self.status == "FAILED": raise ValueError(f"Job failed. 
Errors : {self.errors}") return self._results_as_list() @property - def errors(self) -> Optional[Dict[str, Any]]: + def errors(self) -> Optional[List[Dict[str, Any]]]: # type: ignore return self._errors_as_list() @property - def created_data_rows(self) -> Optional[Dict[str, Any]]: + def created_data_rows( # type: ignore + self) -> Optional[List[Dict[str, Any]]]: return self.result @property - def failed_data_rows(self) -> Optional[Dict[str, Any]]: + def failed_data_rows( # type: ignore + self) -> Optional[List[Dict[str, Any]]]: return self.errors @property @@ -253,7 +262,7 @@ def errors_all(self) -> PaginatedCollection: return self._download_errors_paginated() def _download_results_paginated(self) -> PaginatedCollection: - page_size = 900 # hardcode to avoid overloading the server + page_size = DOWNLOAD_RESULT_PAGE_SIZE from_cursor = None query_str = """query SuccessesfulDataRowImportsPyApi($taskId: ID!, $first: Int, $from: String) { @@ -292,7 +301,7 @@ def _download_results_paginated(self) -> PaginatedCollection: ) def _download_errors_paginated(self) -> PaginatedCollection: - page_size = 5000 # hardcode to avoid overloading the server + page_size = DOWNLOAD_RESULT_PAGE_SIZE # hardcode to avoid overloading the server from_cursor = None query_str = """query FailedDataRowImportsPyApi($taskId: ID!, $first: Int, $from: String) { @@ -306,6 +315,16 @@ def _download_errors_paginated(self) -> PaginatedCollection: externalId globalKey rowData + metadata { + schemaId + value + name + } + attachments { + type + value + name + } } } } @@ -318,28 +337,30 @@ def _download_errors_paginated(self) -> PaginatedCollection: 'from': from_cursor, } + def convert_errors_to_legacy_format(client, data_row): + spec = data_row.get('spec', {}) + return { + 'message': + data_row.get('message'), + 'failedDataRows': [{ + 'externalId': spec.get('externalId'), + 'rowData': spec.get('rowData'), + 'globalKey': spec.get('globalKey'), + 'metadata': spec.get('metadata', []), + 'attachments': spec.get('attachments', []), + }] + } + return PaginatedCollection( client=self.client, query=query_str, params=params, dereferencing=['failedDataRowImports', 'results'], - obj_class=lambda _, data_row: { - 'error': - data_row.get('message'), - 'external_id': - data_row.get('spec').get('externalId') - if data_row.get('spec') else None, - 'row_data': - data_row.get('spec').get('rowData') - if data_row.get('spec') else None, - 'global_key': - data_row.get('spec').get('globalKey') - if data_row.get('spec') else None, - }, + obj_class=convert_errors_to_legacy_format, cursor_path=['failedDataRowImports', 'after'], ) - def _results_as_list(self) -> List[Dict[str, Any]]: + def _results_as_list(self) -> Optional[List[Dict[str, Any]]]: total_downloaded = 0 results = [] data = self._download_results_paginated() @@ -350,9 +371,12 @@ def _results_as_list(self) -> List[Dict[str, Any]]: if total_downloaded >= self.__max_donwload_size: break + if len(results) == 0: + return None + return results - def _errors_as_list(self) -> List[Dict[str, Any]]: + def _errors_as_list(self) -> Optional[List[Dict[str, Any]]]: total_downloaded = 0 errors = [] data = self._download_errors_paginated() @@ -363,4 +387,7 @@ def _errors_as_list(self) -> List[Dict[str, Any]]: if total_downloaded >= self.__max_donwload_size: break + if len(errors) == 0: + return None + return errors diff --git a/libs/labelbox/tests/integration/test_data_rows.py b/libs/labelbox/tests/integration/test_data_rows.py index f02f3816f..c1212460f 100644 --- 
a/libs/labelbox/tests/integration/test_data_rows.py +++ b/libs/labelbox/tests/integration/test_data_rows.py @@ -238,7 +238,7 @@ def test_data_row_bulk_creation_from_file(dataset, local_image_file, image_url): assert task.has_errors() is False results = [r for r in task.result_all] row_data = [result["row_data"] for result in results] - assert row_data == [image_url, image_url] + assert len(row_data) == 2 def test_data_row_bulk_creation_from_row_data_file_external_id( @@ -252,12 +252,14 @@ def test_data_row_bulk_creation_from_row_data_file_external_id( "row_data": image_url, 'external_id': 'some_name2' }]) + task.wait_till_done() assert task.status == "COMPLETE" assert len(task.result) == 2 assert task.has_errors() is False results = [r for r in task.result_all] row_data = [result["row_data"] for result in results] - assert row_data == [image_url, image_url] + assert len(row_data) == 2 + assert image_url in row_data def test_data_row_bulk_creation_from_row_data_file(dataset, rand_gen, @@ -275,7 +277,7 @@ def test_data_row_bulk_creation_from_row_data_file(dataset, rand_gen, assert task.has_errors() is False results = [r for r in task.result_all] row_data = [result["row_data"] for result in results] - assert row_data == [image_url, image_url] + assert len(row_data) == 2 @pytest.mark.slow @@ -899,6 +901,7 @@ def test_create_data_rows_result(client, dataset, image_url): DataRow.external_id: "row1", }, ]) + task.wait_till_done() assert task.errors is None for result in task.result: client.get_data_row(result['id']) @@ -973,8 +976,16 @@ def test_data_row_bulk_creation_with_same_global_keys(dataset, sample_image, 'message'] == f"Duplicate global key: '{global_key_1}'" assert task.failed_data_rows[0]['failedDataRows'][0][ 'externalId'] == sample_image - assert task.created_data_rows[0]['externalId'] == sample_image - assert task.created_data_rows[0]['globalKey'] == global_key_1 + assert task.created_data_rows[0]['external_id'] == sample_image + assert task.created_data_rows[0]['global_key'] == global_key_1 + + errors = task.errors_all + all_errors = [er for er in errors] + assert len(all_errors) == 1 + assert task.has_errors() is True + + all_results = [result for result in task.result_all] + assert len(all_results) == 1 def test_data_row_delete_and_create_with_same_global_key( diff --git a/libs/labelbox/tests/integration/test_data_rows_upsert.py b/libs/labelbox/tests/integration/test_data_rows_upsert.py index c73ae5e5c..accde6dd7 100644 --- a/libs/labelbox/tests/integration/test_data_rows_upsert.py +++ b/libs/labelbox/tests/integration/test_data_rows_upsert.py @@ -208,9 +208,8 @@ def test_multiple_chunks(self, client, dataset, image_url): mocked_chunk_size = 3 with patch('labelbox.client.Client.upload_data', wraps=client.upload_data) as spy_some_function: - with patch( - 'labelbox.schema.dataset.Dataset._Dataset__upsert_chunk_size', - new=mocked_chunk_size): + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', + new=mocked_chunk_size): task = dataset.upsert_data_rows([{ 'row_data': image_url } for i in range(10)]) diff --git a/libs/labelbox/tests/integration/test_dataset.py b/libs/labelbox/tests/integration/test_dataset.py index 8cd83a9bd..82378bee1 100644 --- a/libs/labelbox/tests/integration/test_dataset.py +++ b/libs/labelbox/tests/integration/test_dataset.py @@ -3,7 +3,7 @@ from labelbox import Dataset from labelbox.exceptions import ResourceNotFoundError, MalformedQueryException, InvalidQueryError from labelbox.schema.dataset import MAX_DATAROW_PER_API_OPERATION -from 
labelbox.schema.internal.datarow_uploader import DataRowUploader +from labelbox.schema.internal.data_row_uploader import DataRowUploader def test_dataset(client, rand_gen): @@ -166,12 +166,3 @@ def test_create_descriptor_file(dataset): 'content_type': 'application/json', 'filename': 'json_import.json' } - - -def test_max_dataset_datarow_upload(dataset, image_url, rand_gen): - external_id = str(rand_gen) - items = [dict(row_data=image_url, external_id=external_id) - ] * (MAX_DATAROW_PER_API_OPERATION + 1) - - with pytest.raises(MalformedQueryException): - dataset.create_data_rows(items) diff --git a/libs/labelbox/tests/integration/test_task.py b/libs/labelbox/tests/integration/test_task.py index 66b34d456..3988bf107 100644 --- a/libs/labelbox/tests/integration/test_task.py +++ b/libs/labelbox/tests/integration/test_task.py @@ -61,11 +61,12 @@ def test_task_success_json(dataset, image_url, snapshot): @pytest.mark.export_v1("export_v1 test remove later") def test_task_success_label_export(client, configured_project_with_label): project, _, _, _ = configured_project_with_label - project.export_labels() + # TODO: Move to export_v2 + res = project.export_labels() user = client.get_user() task = None for task in user.created_tasks(): - if task.name != 'JSON Import': + if task.name != 'JSON Import' and task.type != 'adv-upsert-data-rows': break with pytest.raises(ValueError) as exc_info: From 986078a81fdf42643be32fe5dc741c64086ba804 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Mon, 27 May 2024 21:01:28 -0700 Subject: [PATCH 5/9] Add hash function to task --- libs/labelbox/src/labelbox/schema/task.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py index 38859e262..92959da6a 100644 --- a/libs/labelbox/src/labelbox/schema/task.py +++ b/libs/labelbox/src/labelbox/schema/task.py @@ -59,6 +59,9 @@ def __eq__(self, task): return isinstance( task, Task) and task.uid == self.uid and task.type == self.type + def __hash__(self): + return hash(self.uid) + # Import and upsert have several instances of special casing def is_creation_task(self) -> bool: return self.name == 'JSON Import' or self.type == 'adv-upsert-data-rows' From d4fc6957695f91335e18337014c689ea861aa47e Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Wed, 29 May 2024 11:13:18 -0700 Subject: [PATCH 6/9] Remove 150K data rows limitation --- libs/labelbox/src/labelbox/schema/dataset.py | 28 ++++++++----------- .../data/annotation_import/test_model_run.py | 1 + libs/labelbox/tests/integration/test_task.py | 3 +- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py index 9e230340d..2371daef8 100644 --- a/libs/labelbox/src/labelbox/schema/dataset.py +++ b/libs/labelbox/src/labelbox/schema/dataset.py @@ -260,10 +260,10 @@ def create_data_rows_sync(self, items) -> None: url_param: descriptor_url }) - def create_data_rows( - self, - items, - file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT) -> "Task": + def create_data_rows(self, + items, + file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT + ) -> "DataUpsertTask": """ Asynchronously bulk upload data rows Use this instead of `Dataset.create_data_rows_sync` uploads for batches that contain more than 1000 data rows. 
@@ -576,10 +576,10 @@ def _export( is_streamable = res["isStreamable"] return Task.get_task(self.client, task_id), is_streamable - def upsert_data_rows( - self, - items, - file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT) -> "Task": + def upsert_data_rows(self, + items, + file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT + ) -> "DataUpsertTask": """ Upserts data rows in this dataset. When "key" is provided, and it references an existing data row, an update will be performed. When "key" is not provided a new data row will be created. @@ -610,18 +610,14 @@ def upsert_data_rows( >>> ]) >>> task.wait_till_done() """ - if len(items) > MAX_DATAROW_PER_API_OPERATION: - raise MalformedQueryException( - f"Cannot upsert more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call." - ) - specs = DataRowUpsertItem.build(self.uid, items) return self._exec_upsert_data_rows(specs, file_upload_thread_count) def _exec_upsert_data_rows( - self, - specs: List[DataRowItemBase], - file_upload_thread_count: int = FILE_UPLOAD_THREAD_COUNT) -> "Task": + self, + specs: List[DataRowItemBase], + file_upload_thread_count: int = FILE_UPLOAD_THREAD_COUNT + ) -> "DataUpsertTask": manifest = DataRowUploader.upload_in_chunks( client=self.client, diff --git a/libs/labelbox/tests/data/annotation_import/test_model_run.py b/libs/labelbox/tests/data/annotation_import/test_model_run.py index 328b38ba5..3b2e04e62 100644 --- a/libs/labelbox/tests/data/annotation_import/test_model_run.py +++ b/libs/labelbox/tests/data/annotation_import/test_model_run.py @@ -174,6 +174,7 @@ def test_model_run_split_assignment_by_data_row_ids( data_rows = dataset.create_data_rows([{ "row_data": image_url } for _ in range(n_data_rows)]) + data_rows.wait_till_done() data_row_ids = [data_row['id'] for data_row in data_rows.result] configured_project_with_one_data_row._wait_until_data_rows_are_processed( data_row_ids=data_row_ids) diff --git a/libs/labelbox/tests/integration/test_task.py b/libs/labelbox/tests/integration/test_task.py index 3988bf107..07309928a 100644 --- a/libs/labelbox/tests/integration/test_task.py +++ b/libs/labelbox/tests/integration/test_task.py @@ -58,11 +58,12 @@ def test_task_success_json(dataset, image_url, snapshot): 'test_task.test_task_success_json.json') assert len(task.result) + @pytest.mark.export_v1("export_v1 test remove later") def test_task_success_label_export(client, configured_project_with_label): project, _, _, _ = configured_project_with_label # TODO: Move to export_v2 - res = project.export_labels() + project.export_labels() user = client.get_user() task = None for task in user.created_tasks(): From 79d3d0c9ef4ef08b74fd7a765c17a64f58502fde Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Wed, 29 May 2024 11:58:08 -0700 Subject: [PATCH 7/9] Add docstrings --- libs/labelbox/src/labelbox/schema/dataset.py | 43 +++++++++++-------- .../schema/internal/data_row_create_upsert.py | 3 ++ libs/labelbox/src/labelbox/schema/task.py | 18 ++++++++ 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py index 2371daef8..18514767c 100644 --- a/libs/labelbox/src/labelbox/schema/dataset.py +++ b/libs/labelbox/src/labelbox/schema/dataset.py @@ -31,16 +31,13 @@ from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.schema.task import Task, DataUpsertTask from labelbox.schema.user import User -<<<<<<< HEAD from labelbox.schema.iam_integration import IAMIntegration -======= from 
labelbox.schema.internal.data_row_create_upsert import (DataRowItemBase, DataRowUpsertItem, DataRowCreateItem) from labelbox.schema.internal.data_row_uploader import DataRowUploader from labelbox.schema.internal.datarow_upload_constants import ( MAX_DATAROW_PER_API_OPERATION, FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE) ->>>>>>> 58e48ccf (Refactor upsert code so that it can be reused for create) logger = logging.getLogger(__name__) @@ -644,11 +641,13 @@ def _exec_upsert_data_rows( task._user = self.client.get_user() return task - def add_iam_integration(self, iam_integration: Union[str, IAMIntegration]) -> IAMIntegration: + def add_iam_integration( + self, iam_integration: Union[str, + IAMIntegration]) -> IAMIntegration: """ Sets the IAM integration for the dataset. IAM integration is used to sign URLs for data row assets. - Args: + Args: iam_integration (Union[str, IAMIntegration]): IAM integration object or IAM integration id. Returns: @@ -679,7 +678,8 @@ def add_iam_integration(self, iam_integration: Union[str, IAMIntegration]) -> IA >>> dataset.set_iam_integration(iam_integration) """ - iam_integration_id = iam_integration.uid if isinstance(iam_integration, IAMIntegration) else iam_integration + iam_integration_id = iam_integration.uid if isinstance( + iam_integration, IAMIntegration) else iam_integration query = """ mutation SetSignerForDatasetPyApi($signerId: ID!, $datasetId: ID!) { @@ -695,20 +695,30 @@ def add_iam_integration(self, iam_integration: Union[str, IAMIntegration]) -> IA } """ - response = self.client.execute(query, {"signerId": iam_integration_id, "datasetId": self.uid}) + response = self.client.execute(query, { + "signerId": iam_integration_id, + "datasetId": self.uid + }) if not response: - raise ResourceNotFoundError(IAMIntegration, {"signerId": iam_integration_id, "datasetId": self.uid}) - - try: - iam_integration_id = response.get("setSignerForDataset", {}).get("signer", {})["id"] + raise ResourceNotFoundError(IAMIntegration, { + "signerId": iam_integration_id, + "datasetId": self.uid + }) - return [integration for integration - in self.client.get_organization().get_iam_integrations() - if integration.uid == iam_integration_id][0] + try: + iam_integration_id = response.get("setSignerForDataset", + {}).get("signer", {})["id"] + + return [ + integration for integration in + self.client.get_organization().get_iam_integrations() + if integration.uid == iam_integration_id + ][0] except: - raise LabelboxError(f"Can't retrieve IAM integration {iam_integration_id}") - + raise LabelboxError( + f"Can't retrieve IAM integration {iam_integration_id}") + def remove_iam_integration(self) -> None: """ Unsets the IAM integration for the dataset. @@ -738,4 +748,3 @@ def remove_iam_integration(self) -> None: if not response: raise ResourceNotFoundError(Dataset, {"id": self.uid}) - diff --git a/libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py b/libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py index fbbc02b70..4f5b5086e 100644 --- a/libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py +++ b/libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py @@ -6,6 +6,9 @@ class DataRowItemBase(BaseModel, ABC): + """ + Base class for creating payloads for upsert operations. 
+ """ id: dict payload: dict diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py index 92959da6a..6c7eaf019 100644 --- a/libs/labelbox/src/labelbox/schema/task.py +++ b/libs/labelbox/src/labelbox/schema/task.py @@ -230,6 +230,9 @@ def get_task(client, task_id): class DataUpsertTask(Task): + """ + Task class for data row upsert operations + """ __max_donwload_size: Final = MAX_DATAROW_PER_API_OPERATION def __init__(self, *args, **kwargs): @@ -238,12 +241,18 @@ def __init__(self, *args, **kwargs): @property def result(self) -> Optional[List[Dict[str, Any]]]: # type: ignore + """ + Fetches maximum 150K results. If you need to fetch more, use `result_all` property + """ if self.status == "FAILED": raise ValueError(f"Job failed. Errors : {self.errors}") return self._results_as_list() @property def errors(self) -> Optional[List[Dict[str, Any]]]: # type: ignore + """ + Fetches maximum 150K errors. If you need to fetch more, use `errors_all` property + """ return self._errors_as_list() @property @@ -258,10 +267,19 @@ def failed_data_rows( # type: ignore @property def result_all(self) -> PaginatedCollection: + """ + This method uses our standard PaginatedCollection and allow to fetch any number of results + See here for more https://docs.labelbox.com/reference/sdk-fundamental-concepts-1#iterate-over-paginatedcollection + """ return self._download_results_paginated() @property def errors_all(self) -> PaginatedCollection: + """ + This method uses our standard PaginatedCollection and allow to fetch any number of errors + See here for more https://docs.labelbox.com/reference/sdk-fundamental-concepts-1#iterate-over-paginatedcollection + """ + return self._download_errors_paginated() def _download_results_paginated(self) -> PaginatedCollection: From 6ba61bb3beab6c43d4824d6ea7b6906d3bac86c3 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Wed, 29 May 2024 16:32:31 -0700 Subject: [PATCH 8/9] PR feedback --- libs/labelbox/src/labelbox/schema/dataset.py | 17 ++++--- .../schema/internal/data_row_uploader.py | 45 ++++++++----------- ...eate_upsert.py => data_row_upsert_item.py} | 30 +------------ libs/labelbox/src/labelbox/schema/task.py | 6 +-- .../tests/integration/test_data_rows.py | 3 +- .../tests/unit/test_data_row_upsert_data.py | 24 +++++----- 6 files changed, 49 insertions(+), 76 deletions(-) rename libs/labelbox/src/labelbox/schema/internal/{data_row_create_upsert.py => data_row_upsert_item.py} (70%) diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py index 18514767c..4f74e9b7f 100644 --- a/libs/labelbox/src/labelbox/schema/dataset.py +++ b/libs/labelbox/src/labelbox/schema/dataset.py @@ -32,9 +32,7 @@ from labelbox.schema.task import Task, DataUpsertTask from labelbox.schema.user import User from labelbox.schema.iam_integration import IAMIntegration -from labelbox.schema.internal.data_row_create_upsert import (DataRowItemBase, - DataRowUpsertItem, - DataRowCreateItem) +from labelbox.schema.internal.data_row_upsert_item import (DataRowUpsertItem) from labelbox.schema.internal.data_row_uploader import DataRowUploader from labelbox.schema.internal.datarow_upload_constants import ( MAX_DATAROW_PER_API_OPERATION, FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE) @@ -284,13 +282,18 @@ def create_data_rows(self, NOTE dicts and strings items can not be mixed in the same call. It is a responsibility of the caller to ensure that all items are of the same type. 
""" + + if file_upload_thread_count < 1: + raise ValueError( + "file_upload_thread_count must be a positive integer") + string_items = [item for item in items if isinstance(item, str)] dict_items = [item for item in items if isinstance(item, dict)] dict_string_items = [] if len(string_items) > 0: dict_string_items = self._build_from_local_paths(string_items) - specs = DataRowCreateItem.build(self.uid, + specs = DataRowUpsertItem.build(self.uid, dict_items + dict_string_items) return self._exec_upsert_data_rows(specs, file_upload_thread_count) @@ -607,12 +610,12 @@ def upsert_data_rows(self, >>> ]) >>> task.wait_till_done() """ - specs = DataRowUpsertItem.build(self.uid, items) + specs = DataRowUpsertItem.build(self.uid, items, (UniqueId, GlobalKey)) return self._exec_upsert_data_rows(specs, file_upload_thread_count) def _exec_upsert_data_rows( self, - specs: List[DataRowItemBase], + specs: List[DataRowUpsertItem], file_upload_thread_count: int = FILE_UPLOAD_THREAD_COUNT ) -> "DataUpsertTask": @@ -622,7 +625,7 @@ def _exec_upsert_data_rows( file_upload_thread_count=file_upload_thread_count, upsert_chunk_size=UPSERT_CHUNK_SIZE) - data = json.dumps(manifest.to_dict()).encode("utf-8") + data = json.dumps(manifest.dict()).encode("utf-8") manifest_uri = self.client.upload_data(data, content_type="application/json", filename="manifest.json") diff --git a/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py b/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py index 57a6208b2..9be4e2ffd 100644 --- a/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py +++ b/libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py @@ -2,7 +2,7 @@ import os from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Iterable, List, Union +from typing import Iterable, List from labelbox.exceptions import InvalidQueryError from labelbox.exceptions import InvalidAttributeError @@ -10,23 +10,16 @@ from labelbox.orm.model import Entity from labelbox.orm.model import Field from labelbox.schema.embedding import EmbeddingVector -from labelbox.schema.internal.data_row_create_upsert import DataRowItemBase -from labelbox.schema.internal.datarow_upload_constants import MAX_DATAROW_PER_API_OPERATION +from labelbox.pydantic_compat import BaseModel +from labelbox.schema.internal.datarow_upload_constants import ( + MAX_DATAROW_PER_API_OPERATION, FILE_UPLOAD_THREAD_COUNT) +from labelbox.schema.internal.data_row_upsert_item import DataRowUpsertItem -class UploadManifest: - - def __init__(self, source: str, item_count: int, chunk_uris: List[str]): - self.source = source - self.item_count = item_count - self.chunk_uris = chunk_uris - - def to_dict(self): - return { - "source": self.source, - "item_count": self.item_count, - "chunk_uris": self.chunk_uris - } +class UploadManifest(BaseModel): + source: str + item_count: int + chunk_uris: List[str] class DataRowUploader: @@ -39,7 +32,7 @@ def create_descriptor_file(client, """ This function is shared by `Dataset.create_data_rows`, `Dataset.create_data_rows_sync` and `Dataset.update_data_rows`. It is used to prepare the input file. The user defined input is validated, processed, and json stringified. - Finally the json data is uploaded to gcs and a uri is returned. This uri can be passed to + Finally the json data is uploaded to gcs and a uri is returned. This uri can be passed as a parameter to a mutation that uploads data rows Each element in `items` can be either a `str` or a `dict`. 
If it is a `str`, then it is interpreted as a local file path. The file @@ -47,7 +40,7 @@ def create_descriptor_file(client, If an item is a `dict`, then it could support one of the two following structures 1. For static imagery, video, and text it should map `DataRow` field names to values. - At the minimum an `item` passed as a `dict` must contain a `row_data` key and value. + At the minimum an `items` passed as a `dict` must contain a `row_data` key and value. If the value for row_data is a local file path and the path exists, then the local file will be uploaded to labelbox. @@ -81,7 +74,7 @@ def create_descriptor_file(client, a DataRow. ValueError: When the upload parameters are invalid """ - file_upload_thread_count = 20 + file_upload_thread_count = FILE_UPLOAD_THREAD_COUNT DataRow = Entity.DataRow AssetAttachment = Entity.AssetAttachment @@ -192,7 +185,7 @@ def validate_keys(item): raise InvalidAttributeError(DataRow, invalid_keys) return item - def formatLegacyConversationalData(item): + def format_legacy_conversational_data(item): messages = item.pop("conversationalData") version = item.pop("version", 1) type = item.pop("type", "application/vnd.labelbox.conversational") @@ -213,7 +206,7 @@ def formatLegacyConversationalData(item): return item def convert_item(data_row_item): - if isinstance(data_row_item, DataRowItemBase): + if isinstance(data_row_item, DataRowUpsertItem): item = data_row_item.payload else: item = data_row_item @@ -223,7 +216,7 @@ def convert_item(data_row_item): return item if "conversationalData" in item: - formatLegacyConversationalData(item) + format_legacy_conversational_data(item) # Convert all payload variations into the same dict format item = format_row(item) @@ -238,7 +231,7 @@ def convert_item(data_row_item): # Upload any local file paths item = upload_if_necessary(item) - if isinstance(data_row_item, DataRowItemBase): + if isinstance(data_row_item, DataRowUpsertItem): return {'id': data_row_item.id, 'payload': item} else: return item @@ -263,7 +256,7 @@ def convert_item(data_row_item): filename="json_import.json") @staticmethod - def upload_in_chunks(client, specs: List[DataRowItemBase], + def upload_in_chunks(client, specs: List[DataRowUpsertItem], file_upload_thread_count: int, upsert_chunk_size: int) -> UploadManifest: empty_specs = list(filter(lambda spec: spec.is_empty(), specs)) @@ -278,9 +271,9 @@ def upload_in_chunks(client, specs: List[DataRowItemBase], for i in range(0, len(specs), upsert_chunk_size) ] - def _upload_chunk(_chunk): + def _upload_chunk(chunk): return DataRowUploader.create_descriptor_file(client, - _chunk, + chunk, is_upsert=True) with ThreadPoolExecutor(file_upload_thread_count) as executor: diff --git a/libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py b/libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py similarity index 70% rename from libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py rename to libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py index 4f5b5086e..e2c0cb2b5 100644 --- a/libs/labelbox/src/labelbox/schema/internal/data_row_create_upsert.py +++ b/libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py @@ -1,11 +1,10 @@ -from abc import ABC, abstractmethod from typing import List, Tuple, Optional from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.pydantic_compat import BaseModel -class DataRowItemBase(BaseModel, ABC): +class DataRowUpsertItem(BaseModel): """ Base class for creating payloads for upsert 
operations. """ @@ -13,13 +12,12 @@ class DataRowItemBase(BaseModel, ABC): payload: dict @classmethod - @abstractmethod def build( cls, dataset_id: str, items: List[dict], key_types: Optional[Tuple[type, ...]] = () - ) -> List["DataRowItemBase"]: + ) -> List["DataRowUpsertItem"]: upload_items = [] for item in items: @@ -50,27 +48,3 @@ def is_empty(self) -> bool: """ return (not self.payload or len(self.payload.keys()) == 1 and "dataset_id" in self.payload) - - -class DataRowUpsertItem(DataRowItemBase): - - @classmethod - def build( - cls, - dataset_id: str, - items: List[dict], - key_types: Optional[Tuple[type, ...]] = () - ) -> List["DataRowItemBase"]: - return super().build(dataset_id, items, (UniqueId, GlobalKey)) - - -class DataRowCreateItem(DataRowItemBase): - - @classmethod - def build( - cls, - dataset_id: str, - items: List[dict], - key_types: Optional[Tuple[type, ...]] = () - ) -> List["DataRowItemBase"]: - return super().build(dataset_id, items, ()) diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py index 6c7eaf019..1c7401952 100644 --- a/libs/labelbox/src/labelbox/schema/task.py +++ b/libs/labelbox/src/labelbox/schema/task.py @@ -233,7 +233,7 @@ class DataUpsertTask(Task): """ Task class for data row upsert operations """ - __max_donwload_size: Final = MAX_DATAROW_PER_API_OPERATION + MAX_DOWNLOAD_SIZE: Final = MAX_DATAROW_PER_API_OPERATION def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -389,7 +389,7 @@ def _results_as_list(self) -> Optional[List[Dict[str, Any]]]: for row in data: results.append(row) total_downloaded += 1 - if total_downloaded >= self.__max_donwload_size: + if total_downloaded >= self.MAX_DOWNLOAD_SIZE: break if len(results) == 0: @@ -405,7 +405,7 @@ def _errors_as_list(self) -> Optional[List[Dict[str, Any]]]: for row in data: errors.append(row) total_downloaded += 1 - if total_downloaded >= self.__max_donwload_size: + if total_downloaded >= self.MAX_DOWNLOAD_SIZE: break if len(errors) == 0: diff --git a/libs/labelbox/tests/integration/test_data_rows.py b/libs/labelbox/tests/integration/test_data_rows.py index c1212460f..d84ff3778 100644 --- a/libs/labelbox/tests/integration/test_data_rows.py +++ b/libs/labelbox/tests/integration/test_data_rows.py @@ -186,7 +186,8 @@ def test_data_row_bulk_creation(dataset, rand_gen, image_url): { "row_data": image_url }, - ]) + ], + file_upload_thread_count=2) task.wait_till_done() assert task.has_errors() is False assert task.status == "COMPLETE" diff --git a/libs/labelbox/tests/unit/test_data_row_upsert_data.py b/libs/labelbox/tests/unit/test_data_row_upsert_data.py index 93683cfca..9f6eb1400 100644 --- a/libs/labelbox/tests/unit/test_data_row_upsert_data.py +++ b/libs/labelbox/tests/unit/test_data_row_upsert_data.py @@ -1,6 +1,5 @@ import pytest -from labelbox.schema.internal.data_row_create_upsert import (DataRowUpsertItem, - DataRowCreateItem) +from labelbox.schema.internal.data_row_upsert_item import (DataRowUpsertItem) from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.schema.asset_attachment import AttachmentType @@ -13,12 +12,15 @@ def data_row_create_items(): "row_data": "http://my_site.com/photos/img_01.jpg", "global_key": "global_key1", "external_id": "ex_id1", - "attachments": [ - {"type": AttachmentType.RAW_TEXT, "name": "att1", "value": "test1"} - ], - "metadata": [ - {"name": "tag", "value": "tag value"}, - ] + "attachments": [{ + "type": AttachmentType.RAW_TEXT, + "name": "att1", + "value": "test1" + }], + 
"metadata": [{ + "name": "tag", + "value": "tag value" + },] }, ] return dataset_id, items @@ -44,7 +46,7 @@ def test_data_row_upsert_items(data_row_create_items, data_row_update_items): dataset_id, create_items = data_row_create_items dataset_id, update_items = data_row_update_items items = create_items + update_items - result = DataRowUpsertItem.build(dataset_id, items) + result = DataRowUpsertItem.build(dataset_id, items, (UniqueId, GlobalKey)) assert len(result) == len(items) for item, res in zip(items, result): assert res.payload == item @@ -52,7 +54,7 @@ def test_data_row_upsert_items(data_row_create_items, data_row_update_items): def test_data_row_create_items(data_row_create_items): dataset_id, items = data_row_create_items - result = DataRowCreateItem.build(dataset_id, items) + result = DataRowUpsertItem.build(dataset_id, items) assert len(result) == len(items) for item, res in zip(items, result): assert res.payload == item @@ -61,4 +63,4 @@ def test_data_row_create_items(data_row_create_items): def test_data_row_create_items_not_updateable(data_row_update_items): dataset_id, items = data_row_update_items with pytest.raises(ValueError): - DataRowCreateItem.build(dataset_id, items) + DataRowUpsertItem.build(dataset_id, items, ()) From 50098fee0fddc6cc4d53e511177ebb0ded389a5b Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Thu, 30 May 2024 11:59:56 -0700 Subject: [PATCH 9/9] Remove result_all error_all --- libs/labelbox/src/labelbox/schema/task.py | 27 +++---------------- .../tests/integration/test_data_rows.py | 15 ++++------- 2 files changed, 9 insertions(+), 33 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py index 1c7401952..83dea3d1f 100644 --- a/libs/labelbox/src/labelbox/schema/task.py +++ b/libs/labelbox/src/labelbox/schema/task.py @@ -242,7 +242,8 @@ def __init__(self, *args, **kwargs): @property def result(self) -> Optional[List[Dict[str, Any]]]: # type: ignore """ - Fetches maximum 150K results. If you need to fetch more, use `result_all` property + Fetches all results. + Note, for large uploads (>150K data rows), it could take multiple minutes to complete """ if self.status == "FAILED": raise ValueError(f"Job failed. Errors : {self.errors}") @@ -251,7 +252,8 @@ def result(self) -> Optional[List[Dict[str, Any]]]: # type: ignore @property def errors(self) -> Optional[List[Dict[str, Any]]]: # type: ignore """ - Fetches maximum 150K errors. If you need to fetch more, use `errors_all` property + Fetches all errors. 
+ Note, for large uploads / large number of errors (>150K), it could take multiple minutes to complete """ return self._errors_as_list() @@ -265,23 +267,6 @@ def failed_data_rows( # type: ignore self) -> Optional[List[Dict[str, Any]]]: return self.errors - @property - def result_all(self) -> PaginatedCollection: - """ - This method uses our standard PaginatedCollection and allow to fetch any number of results - See here for more https://docs.labelbox.com/reference/sdk-fundamental-concepts-1#iterate-over-paginatedcollection - """ - return self._download_results_paginated() - - @property - def errors_all(self) -> PaginatedCollection: - """ - This method uses our standard PaginatedCollection and allow to fetch any number of errors - See here for more https://docs.labelbox.com/reference/sdk-fundamental-concepts-1#iterate-over-paginatedcollection - """ - - return self._download_errors_paginated() - def _download_results_paginated(self) -> PaginatedCollection: page_size = DOWNLOAD_RESULT_PAGE_SIZE from_cursor = None @@ -389,8 +374,6 @@ def _results_as_list(self) -> Optional[List[Dict[str, Any]]]: for row in data: results.append(row) total_downloaded += 1 - if total_downloaded >= self.MAX_DOWNLOAD_SIZE: - break if len(results) == 0: return None @@ -405,8 +388,6 @@ def _errors_as_list(self) -> Optional[List[Dict[str, Any]]]: for row in data: errors.append(row) total_downloaded += 1 - if total_downloaded >= self.MAX_DOWNLOAD_SIZE: - break if len(errors) == 0: return None diff --git a/libs/labelbox/tests/integration/test_data_rows.py b/libs/labelbox/tests/integration/test_data_rows.py index d84ff3778..30a462f80 100644 --- a/libs/labelbox/tests/integration/test_data_rows.py +++ b/libs/labelbox/tests/integration/test_data_rows.py @@ -196,9 +196,6 @@ def test_data_row_bulk_creation(dataset, rand_gen, image_url): assert len(results) == 2 row_data = [result["row_data"] for result in results] assert row_data == [image_url, image_url] - results_all = task.result_all - row_data = [result["row_data"] for result in results_all] - assert row_data == [image_url, image_url] data_rows = list(dataset.data_rows()) assert len(data_rows) == 2 @@ -237,7 +234,7 @@ def test_data_row_bulk_creation_from_file(dataset, local_image_file, image_url): assert task.status == "COMPLETE" assert len(task.result) == 2 assert task.has_errors() is False - results = [r for r in task.result_all] + results = task.result row_data = [result["row_data"] for result in results] assert len(row_data) == 2 @@ -257,7 +254,7 @@ def test_data_row_bulk_creation_from_row_data_file_external_id( assert task.status == "COMPLETE" assert len(task.result) == 2 assert task.has_errors() is False - results = [r for r in task.result_all] + results = task.result row_data = [result["row_data"] for result in results] assert len(row_data) == 2 assert image_url in row_data @@ -276,7 +273,7 @@ def test_data_row_bulk_creation_from_row_data_file(dataset, rand_gen, assert task.status == "COMPLETE" assert len(task.result) == 2 assert task.has_errors() is False - results = [r for r in task.result_all] + results = task.result row_data = [result["row_data"] for result in results] assert len(row_data) == 2 @@ -980,12 +977,10 @@ def test_data_row_bulk_creation_with_same_global_keys(dataset, sample_image, assert task.created_data_rows[0]['external_id'] == sample_image assert task.created_data_rows[0]['global_key'] == global_key_1 - errors = task.errors_all - all_errors = [er for er in errors] - assert len(all_errors) == 1 + assert len(task.errors) == 1 assert 
task.has_errors() is True - all_results = [result for result in task.result_all] + all_results = task.result assert len(all_results) == 1
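
For reviewers, a minimal usage sketch of the API surface after this series: create_data_rows and upsert_data_rows both return a DataUpsertTask, result/errors return Optional[List[Dict[str, Any]]] (None when empty) and now page through everything internally since result_all/errors_all were removed, and failed rows keep the legacy message/failedDataRows shape produced by convert_errors_to_legacy_format. This is a sketch only; the API key, dataset name, and URLs below are placeholders, not values taken from the patches.

    # Sketch only: exercises the DataUpsertTask surface as changed in this series.
    from labelbox import Client

    client = Client(api_key="<YOUR_API_KEY>")            # placeholder credentials
    dataset = client.create_dataset(name="upsert-demo")  # placeholder dataset name

    task = dataset.create_data_rows(
        [{"row_data": "https://example.com/img.jpg", "global_key": "gk-1"}],
        file_upload_thread_count=2,  # optional parameter introduced in this series
    )
    task.wait_till_done()

    if task.errors is not None:
        for err in task.errors:
            # Legacy shape rebuilt by convert_errors_to_legacy_format:
            # {"message": ..., "failedDataRows": [{"externalId", "rowData", "globalKey", ...}]}
            print(err["message"], err["failedDataRows"])

    if task.result is not None:
        for row in task.result:
            print(row["id"], row["global_key"])

The upsert path reuses the same backend; an item with a "key" (UniqueId or GlobalKey) updates an existing data row, while items without one are created:

    from labelbox.schema.identifiable import GlobalKey

    task = dataset.upsert_data_rows([
        {"key": GlobalKey("gk-1"), "row_data": "https://example.com/img_v2.jpg"},
        {"row_data": "https://example.com/new.jpg", "global_key": "gk-2"},
    ])
    task.wait_till_done()
    assert task.errors is None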