
Commit 5459996

[PLT-43] Vb/create datarows chunking plt 43 (#1627)
2 parents 8bd5ce5 + 50098fe

File tree

12 files changed: +789 -401 lines changed

libs/labelbox/src/labelbox/client.py

Lines changed: 1 addition & 0 deletions
@@ -407,6 +407,7 @@ def upload_data(self,
             }),
             "map": (None, json.dumps({"1": ["variables.file"]})),
         }
+
         response = requests.post(
             self.endpoint,
             headers={"authorization": "Bearer %s" % self.api_key},

libs/labelbox/src/labelbox/schema/dataset.py

Lines changed: 88 additions & 346 deletions
Large diffs are not rendered by default.
libs/labelbox/src/labelbox/schema/internal/data_row_uploader.py

Lines changed: 287 additions & 0 deletions
@@ -0,0 +1,287 @@
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

from typing import Iterable, List

from labelbox.exceptions import InvalidQueryError
from labelbox.exceptions import InvalidAttributeError
from labelbox.exceptions import MalformedQueryException
from labelbox.orm.model import Entity
from labelbox.orm.model import Field
from labelbox.schema.embedding import EmbeddingVector
from labelbox.pydantic_compat import BaseModel
from labelbox.schema.internal.datarow_upload_constants import (
    MAX_DATAROW_PER_API_OPERATION, FILE_UPLOAD_THREAD_COUNT)
from labelbox.schema.internal.data_row_upsert_item import DataRowUpsertItem


class UploadManifest(BaseModel):
    source: str
    item_count: int
    chunk_uris: List[str]


class DataRowUploader:

    @staticmethod
    def create_descriptor_file(client,
                               items,
                               max_attachments_per_data_row=None,
                               is_upsert=False):
        """
        This function is shared by `Dataset.create_data_rows`, `Dataset.create_data_rows_sync` and `Dataset.update_data_rows`.
        It is used to prepare the input file: the user-defined input is validated, processed, and JSON-stringified.
        Finally, the JSON data is uploaded to GCS and a URI is returned. This URI can be passed as a parameter to a mutation that uploads data rows.

        Each element in `items` can be either a `str` or a `dict`. If
        it is a `str`, then it is interpreted as a local file path. The file
        is uploaded to Labelbox and a DataRow referencing it is created.

        If an item is a `dict`, it must match one of the following two structures:
            1. For static imagery, video, and text it should map `DataRow` field names to values.
               At a minimum, an item passed as a `dict` must contain a `row_data` key and value.
               If the value for `row_data` is a local file path and the path exists,
               then the local file will be uploaded to Labelbox.

            2. For tiled imagery the dict must match the import structure specified in the link below:
               https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import

        >>> dataset.create_data_rows([
        >>>     {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"},
        >>>     {DataRow.row_data:"/path/to/file1.jpg"},
        >>>     "path/to/file2.jpg",
        >>>     {DataRow.row_data: {"tileLayerUrl" : "http://", ...}},
        >>>     {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}}
        >>>     ])

        For an example showing how to upload tiled data rows see the following notebook:
            https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb

        Args:
            items (iterable of (dict or str)): See above for details.
            max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine
                if the user has provided too many attachments.

        Returns:
            uri (string): A reference to the uploaded JSON data.

        Raises:
            InvalidQueryError: If the `items` parameter does not conform to
                the specification above or if the server did not accept the
                DataRow creation request (unknown reason).
            InvalidAttributeError: If there are fields in `items` not valid for
                a DataRow.
            ValueError: When the upload parameters are invalid.
        """
        file_upload_thread_count = FILE_UPLOAD_THREAD_COUNT
        DataRow = Entity.DataRow
        AssetAttachment = Entity.AssetAttachment

        def upload_if_necessary(item):
            if is_upsert and 'row_data' not in item:
                # When upserting, row_data is not required
                return item
            row_data = item['row_data']
            if isinstance(row_data, str) and os.path.exists(row_data):
                item_url = client.upload_file(row_data)
                item['row_data'] = item_url
                if 'external_id' not in item:
                    # Default `external_id` to the local file name
                    item['external_id'] = row_data
            return item

        def validate_attachments(item):
            attachments = item.get('attachments')
            if attachments:
                if isinstance(attachments, list):
                    if max_attachments_per_data_row and len(
                            attachments) > max_attachments_per_data_row:
                        raise ValueError(
                            f"Max number of supported attachments per data row is {max_attachments_per_data_row}."
                            f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary."
                        )
                    for attachment in attachments:
                        AssetAttachment.validate_attachment_json(attachment)
                else:
                    raise ValueError(
                        f"Attachments must be a list. Found {type(attachments)}"
                    )
            return attachments

        def validate_embeddings(item):
            embeddings = item.get("embeddings")
            if embeddings:
                item["embeddings"] = [
                    EmbeddingVector(**e).to_gql() for e in embeddings
                ]

        def validate_conversational_data(conversational_data: list) -> None:
            """
            Checks each conversational message for the keys expected per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json

            Args:
                conversational_data (list): list of dictionaries.
            """

            def check_message_keys(message):
                accepted_message_keys = set([
                    "messageId", "timestampUsec", "content", "user", "align",
                    "canLabel"
                ])
                for key in message.keys():
                    if key not in accepted_message_keys:
                        raise KeyError(
                            f"Invalid {key} key found! Accepted keys in the messages list are {accepted_message_keys}"
                        )

            if conversational_data and not isinstance(conversational_data,
                                                      list):
                raise ValueError(
                    f"conversationalData must be a list. Found {type(conversational_data)}"
                )

            [check_message_keys(message) for message in conversational_data]

        def parse_metadata_fields(item):
            metadata_fields = item.get('metadata_fields')
            if metadata_fields:
                mdo = client.get_data_row_metadata_ontology()
                item['metadata_fields'] = mdo.parse_upsert_metadata(
                    metadata_fields)

        def format_row(item):
            # Formats user input into a consistent dict structure
            if isinstance(item, dict):
                # Convert fields to strings
                item = {
                    key.name if isinstance(key, Field) else key: value
                    for key, value in item.items()
                }
            elif isinstance(item, str):
                # The main advantage of using a string over a dict is that the user is specifying
                # that the file should exist locally.
                # That info is lost after this section so we should check for it here.
                if not os.path.exists(item):
                    raise ValueError(f"Filepath {item} does not exist.")
                item = {"row_data": item, "external_id": item}
            return item

        def validate_keys(item):
            if not is_upsert and 'row_data' not in item:
                raise InvalidQueryError(
                    "`row_data` missing when creating DataRow.")

            if isinstance(item.get('row_data'),
                          str) and item.get('row_data').startswith("s3:/"):
                raise InvalidQueryError(
                    "row_data: s3 assets must start with 'https'.")
            allowed_extra_fields = {
                'attachments', 'media_type', 'dataset_id', 'embeddings'
            }
            invalid_keys = set(item) - {f.name for f in DataRow.fields()
                                       } - allowed_extra_fields
            if invalid_keys:
                raise InvalidAttributeError(DataRow, invalid_keys)
            return item

        def format_legacy_conversational_data(item):
            messages = item.pop("conversationalData")
            version = item.pop("version", 1)
            type = item.pop("type", "application/vnd.labelbox.conversational")
            if "externalId" in item:
                external_id = item.pop("externalId")
                item["external_id"] = external_id
            if "globalKey" in item:
                global_key = item.pop("globalKey")
                item["globalKey"] = global_key
            validate_conversational_data(messages)
            one_conversation = \
                {
                    "type": type,
                    "version": version,
                    "messages": messages
                }
            item["row_data"] = one_conversation
            return item

        def convert_item(data_row_item):
            if isinstance(data_row_item, DataRowUpsertItem):
                item = data_row_item.payload
            else:
                item = data_row_item

            if "tileLayerUrl" in item:
                validate_attachments(item)
                return item

            if "conversationalData" in item:
                format_legacy_conversational_data(item)

            # Convert all payload variations into the same dict format
            item = format_row(item)
            # Make sure required keys exist (and there are no extra keys)
            validate_keys(item)
            # Make sure attachments are valid
            validate_attachments(item)
            # Make sure embeddings are valid
            validate_embeddings(item)
            # Parse metadata fields if they exist
            parse_metadata_fields(item)
            # Upload any local file paths
            item = upload_if_necessary(item)

            if isinstance(data_row_item, DataRowUpsertItem):
                return {'id': data_row_item.id, 'payload': item}
            else:
                return item

        if not isinstance(items, Iterable):
            raise ValueError(
                f"Must pass an iterable to create_data_rows. Found {type(items)}"
            )

        if len(items) > MAX_DATAROW_PER_API_OPERATION:
            raise MalformedQueryException(
                f"Cannot create more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call."
            )

        with ThreadPoolExecutor(file_upload_thread_count) as executor:
            futures = [executor.submit(convert_item, item) for item in items]
            items = [future.result() for future in as_completed(futures)]
        # Prepare and upload the descriptor file
        data = json.dumps(items)
        return client.upload_data(data,
                                  content_type="application/json",
                                  filename="json_import.json")

    @staticmethod
    def upload_in_chunks(client, specs: List[DataRowUpsertItem],
                         file_upload_thread_count: int,
                         upsert_chunk_size: int) -> UploadManifest:
        empty_specs = list(filter(lambda spec: spec.is_empty(), specs))

        if empty_specs:
            ids = list(map(lambda spec: spec.id.get("value"), empty_specs))
            raise ValueError(
                f"The following items have an empty payload: {ids}")

        chunks = [
            specs[i:i + upsert_chunk_size]
            for i in range(0, len(specs), upsert_chunk_size)
        ]

        def _upload_chunk(chunk):
            return DataRowUploader.create_descriptor_file(client,
                                                          chunk,
                                                          is_upsert=True)

        with ThreadPoolExecutor(file_upload_thread_count) as executor:
            futures = [
                executor.submit(_upload_chunk, chunk) for chunk in chunks
            ]
            chunk_uris = [future.result() for future in as_completed(futures)]

        return UploadManifest(source="SDK",
                              item_count=len(specs),
                              chunk_uris=chunk_uris)
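
The dataset.py diff that wires this into Dataset is collapsed above ("Large diffs are not rendered by default"), so the exact call site is not visible here. The following is a minimal sketch of how the chunked upsert pieces added in this commit fit together, assuming the DataRowUploader module path, the client credentials, the dataset id, and the row-data URLs; it is not the Dataset implementation itself.

# Sketch only: call-site wiring and the DataRowUploader import path are assumptions.
from labelbox import Client
from labelbox.schema.identifiable import UniqueId, GlobalKey
from labelbox.schema.internal.data_row_upsert_item import DataRowUpsertItem
from labelbox.schema.internal.datarow_upload_constants import (
    FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE)
from labelbox.schema.internal.data_row_uploader import DataRowUploader  # assumed path

client = Client(api_key="<LABELBOX_API_KEY>")  # placeholder credentials

# Hypothetical upsert payloads: one addressed by a global key, one left to AUTO.
items = [
    {"row_data": "https://example.com/img_01.jpg", "key": GlobalKey("img-01")},
    {"row_data": "https://example.com/img_02.jpg"},
]

specs = DataRowUpsertItem.build(
    dataset_id="<dataset-id>",  # hypothetical dataset id
    items=items,
    key_types=(UniqueId, GlobalKey),
)

# Splits the specs into upsert_chunk_size-sized chunks, uploads one descriptor
# file per chunk on a thread pool, and returns a manifest with the chunk URIs.
manifest = DataRowUploader.upload_in_chunks(
    client,
    specs,
    file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT,
    upsert_chunk_size=UPSERT_CHUNK_SIZE,
)
print(manifest.item_count, manifest.chunk_uris)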
libs/labelbox/src/labelbox/schema/internal/data_row_upsert_item.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
from typing import List, Tuple, Optional

from labelbox.schema.identifiable import UniqueId, GlobalKey
from labelbox.pydantic_compat import BaseModel


class DataRowUpsertItem(BaseModel):
    """
    Base class for creating payloads for upsert operations.
    """
    id: dict
    payload: dict

    @classmethod
    def build(
        cls,
        dataset_id: str,
        items: List[dict],
        key_types: Optional[Tuple[type, ...]] = ()
    ) -> List["DataRowUpsertItem"]:
        upload_items = []

        for item in items:
            # Enforce the current dataset's id for all specs
            item['dataset_id'] = dataset_id
            key = item.pop('key', None)
            if not key:
                key = {'type': 'AUTO', 'value': ''}
            elif isinstance(key, key_types):  # type: ignore
                key = {'type': key.id_type.value, 'value': key.key}
            else:
                if not key_types:
                    raise ValueError(
                        f"Cannot have a key for this item, got: {key}")
                raise ValueError(
                    f"Key must be an instance of {', '.join([t.__name__ for t in key_types])}, got: {type(item['key']).__name__}"
                )
            item = {
                k: v for k, v in item.items() if v is not None
            }  # remove None values
            upload_items.append(cls(payload=item, id=key))
        return upload_items

    def is_empty(self) -> bool:
        """
        The payload is considered empty if it is actually empty or its only key is `dataset_id`.
        :return: bool
        """
        return (not self.payload or
                len(self.payload.keys()) == 1 and "dataset_id" in self.payload)
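
A minimal sketch (dataset id and row-data URLs assumed) of the key normalization and is_empty() behavior defined above: an explicit identifiable key is converted to a {'type': ..., 'value': ...} dict built from its id_type, a missing key falls back to AUTO, and a spec whose payload ends up holding only dataset_id is flagged as empty.

from labelbox.schema.identifiable import UniqueId, GlobalKey
from labelbox.schema.internal.data_row_upsert_item import DataRowUpsertItem

specs = DataRowUpsertItem.build(
    dataset_id="<dataset-id>",  # hypothetical id
    items=[
        {"row_data": "https://example.com/a.jpg", "key": GlobalKey("row-a")},
        {"row_data": "https://example.com/b.jpg"},  # no key -> {'type': 'AUTO', 'value': ''}
        {},                                         # payload reduces to dataset_id only
    ],
    key_types=(UniqueId, GlobalKey),
)

print([s.id for s in specs])          # first entry carries the GlobalKey's id_type/value
print([s.is_empty() for s in specs])  # [False, False, True]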
libs/labelbox/src/labelbox/schema/internal/datarow_upload_constants.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
MAX_DATAROW_PER_API_OPERATION = 150_000
FILE_UPLOAD_THREAD_COUNT = 20
UPSERT_CHUNK_SIZE = 10_000
DOWNLOAD_RESULT_PAGE_SIZE = 5_000
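
A quick sanity check of how the first and third constants interact in upload_in_chunks above: a maximal call of MAX_DATAROW_PER_API_OPERATION rows is split into at most 15 chunks, each uploaded as its own descriptor file.

from labelbox.schema.internal.datarow_upload_constants import (
    MAX_DATAROW_PER_API_OPERATION, UPSERT_CHUNK_SIZE)

# 150_000 rows / 10_000 rows per chunk -> at most 15 descriptor files per upsert.
max_chunks = -(-MAX_DATAROW_PER_API_OPERATION // UPSERT_CHUNK_SIZE)  # ceiling division
print(max_chunks)  # 15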
