import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

from typing import Iterable, List

from labelbox.exceptions import InvalidQueryError
from labelbox.exceptions import InvalidAttributeError
from labelbox.exceptions import MalformedQueryException
from labelbox.orm.model import Entity
from labelbox.orm.model import Field
from labelbox.schema.embedding import EmbeddingVector
from labelbox.pydantic_compat import BaseModel
from labelbox.schema.internal.datarow_upload_constants import (
    MAX_DATAROW_PER_API_OPERATION, FILE_UPLOAD_THREAD_COUNT)
from labelbox.schema.internal.data_row_upsert_item import DataRowUpsertItem
class UploadManifest(BaseModel):
    source: str
    item_count: int
    chunk_uris: List[str]
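    # Illustrative shape of the manifest returned by `upload_in_chunks` below;
    # the values here are hypothetical:
    #     UploadManifest(source="SDK", item_count=2500,
    #                    chunk_uris=["https://storage...", "https://storage..."])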


class DataRowUploader:

    @staticmethod
    def create_descriptor_file(client,
                               items,
                               max_attachments_per_data_row=None,
                               is_upsert=False):
        """
        This function is shared by `Dataset.create_data_rows`, `Dataset.create_data_rows_sync`, and `Dataset.update_data_rows`.
        It is used to prepare the input file. The user-defined input is validated, processed, and JSON-stringified.
        Finally, the JSON data is uploaded to GCS and a URI is returned. This URI can be passed as a parameter to a mutation that uploads data rows.

        Each element in `items` can be either a `str` or a `dict`. If
        it is a `str`, then it is interpreted as a local file path. The file
        is uploaded to Labelbox and a DataRow referencing it is created.

        If an item is a `dict`, it must follow one of the two structures below:

        1. For static imagery, video, and text, it should map `DataRow` field names to values.
           At a minimum, an item passed as a `dict` must contain a `row_data` key and value.
           If the value for `row_data` is a local file path and the path exists,
           then the local file will be uploaded to Labelbox.

        2. For tiled imagery, the dict must match the import structure specified in the link below:
           https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import

        >>> dataset.create_data_rows([
        >>>     {DataRow.row_data: "http://my_site.com/photos/img_01.jpg"},
        >>>     {DataRow.row_data: "/path/to/file1.jpg"},
        >>>     "path/to/file2.jpg",
        >>>     {DataRow.row_data: {"tileLayerUrl": "http://", ...}},
        >>>     {DataRow.row_data: {"type": ..., "version": ..., "messages": [...]}}
        >>> ])

        For an example showing how to upload tiled data rows, see the following notebook:
        https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb

        Args:
            items (iterable of (dict or str)): See above for details.
            max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine
                whether the user has provided too many attachments.
            is_upsert (bool): Whether the items are being prepared for an upsert;
                when True, `row_data` is not required on every item.

        Returns:
            uri (string): A reference to the uploaded JSON data.

        Raises:
            InvalidQueryError: If the `items` parameter does not conform to
                the specification above, or if the server did not accept the
                DataRow creation request (unknown reason).
            InvalidAttributeError: If there are fields in `items` not valid for
                a DataRow.
            ValueError: When the upload parameters are invalid.
        """
        file_upload_thread_count = FILE_UPLOAD_THREAD_COUNT
        DataRow = Entity.DataRow
        AssetAttachment = Entity.AssetAttachment

        def upload_if_necessary(item):
            if is_upsert and 'row_data' not in item:
                # When upserting, row_data is not required
                return item
            row_data = item['row_data']
            if isinstance(row_data, str) and os.path.exists(row_data):
                item_url = client.upload_file(row_data)
                item['row_data'] = item_url
                if 'external_id' not in item:
                    # Default `external_id` to local file name
                    item['external_id'] = row_data
            return item
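        # Illustrative transformation performed by `upload_if_necessary`
        # (the path and URL below are hypothetical):
        #     {"row_data": "/tmp/img_01.jpg"}
        # becomes something like
        #     {"row_data": "https://storage.labelbox.com/...",
        #      "external_id": "/tmp/img_01.jpg"}
        # Items whose row_data is already a URL (or a dict) pass through unchanged.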

        def validate_attachments(item):
            attachments = item.get('attachments')
            if attachments:
                if isinstance(attachments, list):
                    if max_attachments_per_data_row and len(
                            attachments) > max_attachments_per_data_row:
                        raise ValueError(
                            f"Max number of supported attachments per data row is {max_attachments_per_data_row}."
                            f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary."
                        )
                    for attachment in attachments:
                        AssetAttachment.validate_attachment_json(attachment)
                else:
                    raise ValueError(
                        f"Attachments must be a list. Found {type(attachments)}"
                    )
            return attachments

        def validate_embeddings(item):
            embeddings = item.get("embeddings")
            if embeddings:
                item["embeddings"] = [
                    EmbeddingVector(**e).to_gql() for e in embeddings
                ]
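        # For reference, each entry in `embeddings` is expected to unpack into an
        # EmbeddingVector; a plausible shape (an assumption here, not verified
        # against the schema) is:
        #     {"embedding_id": "emb-1", "vector": [0.1, 0.2, 0.3]}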

        def validate_conversational_data(conversational_data: list) -> None:
            """
            Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json

            Args:
                conversational_data (list): list of dictionaries.
            """

            def check_message_keys(message):
                accepted_message_keys = set([
                    "messageId", "timestampUsec", "content", "user", "align",
                    "canLabel"
                ])
                for key in message.keys():
                    if key not in accepted_message_keys:
                        raise KeyError(
                            f"Invalid key '{key}' found! Accepted keys in the messages list are {accepted_message_keys}"
                        )

            if conversational_data and not isinstance(conversational_data,
                                                      list):
                raise ValueError(
                    f"conversationalData must be a list. Found {type(conversational_data)}"
                )

            for message in conversational_data:
                check_message_keys(message)
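        # A message that passes `check_message_keys`; the values (and the shape of
        # the `user` sub-dict) are illustrative:
        #     {"messageId": "message-1", "timestampUsec": 1669000000000,
        #      "content": "Hello", "user": {"userId": "u1", "name": "A"},
        #      "align": "left", "canLabel": True}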

        def parse_metadata_fields(item):
            metadata_fields = item.get('metadata_fields')
            if metadata_fields:
                mdo = client.get_data_row_metadata_ontology()
                item['metadata_fields'] = mdo.parse_upsert_metadata(
                    metadata_fields)
| 153 | + def format_row(item): |
| 154 | + # Formats user input into a consistent dict structure |
| 155 | + if isinstance(item, dict): |
| 156 | + # Convert fields to strings |
| 157 | + item = { |
| 158 | + key.name if isinstance(key, Field) else key: value |
| 159 | + for key, value in item.items() |
| 160 | + } |
| 161 | + elif isinstance(item, str): |
| 162 | + # The main advantage of using a string over a dict is that the user is specifying |
| 163 | + # that the file should exist locally. |
| 164 | + # That info is lost after this section so we should check for it here. |
| 165 | + if not os.path.exists(item): |
| 166 | + raise ValueError(f"Filepath {item} does not exist.") |
| 167 | + item = {"row_data": item, "external_id": item} |
| 168 | + return item |
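        # e.g. format_row("path/to/file2.jpg") yields, assuming the file exists:
        #     {"row_data": "path/to/file2.jpg", "external_id": "path/to/file2.jpg"}
        # while {DataRow.row_data: "https://..."} becomes {"row_data": "https://..."}.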

        def validate_keys(item):
            if not is_upsert and 'row_data' not in item:
                raise InvalidQueryError(
                    "`row_data` missing when creating DataRow.")

            row_data = item.get('row_data')
            if isinstance(row_data, str) and row_data.startswith("s3:/"):
                raise InvalidQueryError(
                    "row_data: s3 assets must start with 'https'.")
            allowed_extra_fields = {
                'attachments', 'media_type', 'dataset_id', 'embeddings'
            }
            invalid_keys = (set(item) - {f.name for f in DataRow.fields()} -
                            allowed_extra_fields)
            if invalid_keys:
                raise InvalidAttributeError(DataRow, invalid_keys)
            return item
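        # e.g. {"row_data": "s3://bucket/img.jpg"} raises InvalidQueryError, and an
        # unknown key such as in {"row_data": "https://...", "foo": 1} raises
        # InvalidAttributeError (the key names here are made up for illustration).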

        def format_legacy_conversational_data(item):
            messages = item.pop("conversationalData")
            version = item.pop("version", 1)
            content_type = item.pop("type",
                                    "application/vnd.labelbox.conversational")
            if "externalId" in item:
                external_id = item.pop("externalId")
                item["external_id"] = external_id
            if "globalKey" in item:
                # Normalize the legacy camelCase key to the DataRow field name,
                # mirroring the externalId handling above; otherwise validate_keys
                # would reject "globalKey" as an invalid field.
                global_key = item.pop("globalKey")
                item["global_key"] = global_key
            validate_conversational_data(messages)
            item["row_data"] = {
                "type": content_type,
                "version": version,
                "messages": messages
            }
            return item
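        # Illustrative legacy payload (abbreviated) and its converted form:
        #     {"conversationalData": [...], "globalKey": "gk-1"}
        # becomes
        #     {"row_data": {"type": "application/vnd.labelbox.conversational",
        #                   "version": 1, "messages": [...]},
        #      "global_key": "gk-1"}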

        def convert_item(data_row_item):
            if isinstance(data_row_item, DataRowUpsertItem):
                item = data_row_item.payload
            else:
                item = data_row_item

            if "tileLayerUrl" in item:
                validate_attachments(item)
                return item

            if "conversationalData" in item:
                format_legacy_conversational_data(item)

            # Convert all payload variations into the same dict format
            item = format_row(item)
            # Make sure required keys exist (and there are no extra keys)
            validate_keys(item)
            # Make sure attachments are valid
            validate_attachments(item)
            # Make sure embeddings are valid
            validate_embeddings(item)
            # Parse metadata fields if they exist
            parse_metadata_fields(item)
            # Upload any local file paths
            item = upload_if_necessary(item)

            if isinstance(data_row_item, DataRowUpsertItem):
                return {'id': data_row_item.id, 'payload': item}
            else:
                return item

        if not isinstance(items, Iterable):
            raise ValueError(
                f"Must pass an iterable to create_data_rows. Found {type(items)}"
            )

        if len(items) > MAX_DATAROW_PER_API_OPERATION:
            raise MalformedQueryException(
                f"Cannot create more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call."
            )

        with ThreadPoolExecutor(file_upload_thread_count) as executor:
            futures = [executor.submit(convert_item, item) for item in items]
            items = [future.result() for future in as_completed(futures)]
        # Prepare and upload the descriptor file
        data = json.dumps(items)
        return client.upload_data(data,
                                  content_type="application/json",
                                  filename="json_import.json")
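    # Minimal usage sketch (assumes a configured `client` and an existing local
    # file; the item values are illustrative):
    #     uri = DataRowUploader.create_descriptor_file(
    #         client,
    #         ["path/to/file2.jpg", {"row_data": "https://example.com/a.jpg"}])
    #     # `uri` can then be passed to a data-row import/upsert mutation.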

    @staticmethod
    def upload_in_chunks(client, specs: List[DataRowUpsertItem],
                         file_upload_thread_count: int,
                         upsert_chunk_size: int) -> UploadManifest:
        empty_specs = list(filter(lambda spec: spec.is_empty(), specs))

        if empty_specs:
            ids = list(map(lambda spec: spec.id.get("value"), empty_specs))
            raise ValueError(
                f"The following items have an empty payload: {ids}")

        chunks = [
            specs[i:i + upsert_chunk_size]
            for i in range(0, len(specs), upsert_chunk_size)
        ]
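        # e.g. upsert_chunk_size=1000 splits 2500 specs into chunks of 1000, 1000,
        # and 500 items; each chunk is uploaded as its own descriptor file below.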

        def _upload_chunk(chunk):
            return DataRowUploader.create_descriptor_file(client,
                                                          chunk,
                                                          is_upsert=True)

        with ThreadPoolExecutor(file_upload_thread_count) as executor:
            futures = [
                executor.submit(_upload_chunk, chunk) for chunk in chunks
            ]
            chunk_uris = [future.result() for future in as_completed(futures)]

        return UploadManifest(source="SDK",
                              item_count=len(specs),
                              chunk_uris=chunk_uris)